Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reshape able slice #1241

Merged
merged 22 commits into from
Aug 10, 2020
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
7bbd861
Added Caffe Slice_ext
pavel-esir Jul 6, 2020
be9b281
Added TFSlice, AttributedSlice (both with extractors and replacers), …
pavel-esir Jul 7, 2020
8623741
added comments to each type of Slice operation; optimized shape infer…
pavel-esir Jul 13, 2020
09de059
removed type annotation for get_shape_after_slice routine
pavel-esir Jul 13, 2020
052833f
replaced zeros_like with zeros
pavel-esir Jul 13, 2020
287216c
Corrected preserving node names, renamed attributes names, added test…
pavel-esir Jul 17, 2020
683d2c3
Renamed slice_replacers.py
pavel-esir Jul 20, 2020
876801b
added more unittest cases
pavel-esir Jul 20, 2020
d9606a5
added type annotations, moved to more relevant place routines for sha…
pavel-esir Jul 31, 2020
e2d91ef
corrected a typo `normalize_slice_indices` comment
pavel-esir Jul 31, 2020
026750d
corrected shape calculation for Nonconstant inputs
pavel-esir Jul 31, 2020
89cb9f3
corrected a few typos
pavel-esir Aug 2, 2020
ce79359
corrected type declarations
pavel-esir Aug 3, 2020
724b44e
corrected shape inference with rounding
pavel-esir Aug 5, 2020
a28d5bc
refactored unit-tests for front transforms of Slice
pavel-esir Aug 5, 2020
4335670
added error raising for negative and zero shapes
pavel-esir Aug 5, 2020
d9c9847
removed magic_num
pavel-esir Aug 6, 2020
57df332
corrected AttributedSlice, clarified comments
pavel-esir Aug 6, 2020
ab6b005
fixed unit-test for AttributedSliceToSlice
pavel-esir Aug 6, 2020
9c23fab
typo in type hints corrected
pavel-esir Aug 7, 2020
0278a7c
removed supported_attrs
pavel-esir Aug 7, 2020
831be1d
returned back default None for attrs of Slice
pavel-esir Aug 7, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ extensions/front/caffe/reshape.py
extensions/front/caffe/shufflechannel_ext.py
extensions/front/caffe/sigmoid.py
extensions/front/caffe/simplernms_ext.py
extensions/front/caffe/slice_ext.py
extensions/front/caffe/slice_to_split.py
extensions/front/caffe/softmax_ext.py
extensions/front/caffe/spatial_transformer_ext.py
Expand Down Expand Up @@ -231,6 +232,7 @@ extensions/front/onnx/activation_ext.py
extensions/front/onnx/affine_ext.py
extensions/front/onnx/argmax_ext.py
extensions/front/onnx/aten_ext.py
extensions/front/onnx/AttributedSliceToSlice.py
extensions/front/onnx/cast_ext.py
extensions/front/onnx/clip_ext.py
extensions/front/onnx/const_ext.py
Expand Down Expand Up @@ -447,6 +449,7 @@ extensions/front/tf/SwitchMergeOptimization.py
extensions/front/tf/TensorArrayExtractors.py
extensions/front/tf/TensorArrayGatherV3.py
extensions/front/tf/tensorflow_custom_operations_config_update.py
extensions/front/tf/TFSliceToSlice.py
extensions/front/tf/tile_ext.py
extensions/front/tf/topk_ext.py
extensions/front/tf/transpose_ext.py
Expand Down Expand Up @@ -624,7 +627,6 @@ extensions/ops/merge.py
extensions/ops/mvn.py
extensions/ops/mxrepeat.py
extensions/ops/mxreshape.py
extensions/ops/mxslice.py
extensions/ops/NextIteration.py
extensions/ops/non_max_suppression.py
extensions/ops/non_zero.py
Expand Down Expand Up @@ -723,7 +725,6 @@ mo/front/caffe/extractors/crop.py
mo/front/caffe/extractors/native_caffe.py
mo/front/caffe/extractors/roipooling.py
mo/front/caffe/extractors/scale.py
mo/front/caffe/extractors/slice.py
mo/front/caffe/extractors/tile.py
mo/front/caffe/extractors/utils.py
mo/front/caffe/loader.py
Expand Down
46 changes: 46 additions & 0 deletions model-optimizer/extensions/front/caffe/slice_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import FrontExtractorOp
from mo.ops.slice import CaffeSlice


class SliceFrontExtractor(FrontExtractorOp):
op = 'slice'
enabled = True

@classmethod
def extract(cls, node):
proto_layer = node.pb
param = proto_layer.slice_param

# slice_dim is deprecated parameter and is used as alias for axis
# however if slice_dim is defined and axis is default, we use slice_dim
if param.slice_dim != 1 and param.axis == 1:
axis = param.slice_dim
else:
axis = param.axis

update_attrs = {
'axis': axis,
'slice_point': int64_array(param.slice_point),
'in_ports_count': 1,
'out_ports_count': len(param.slice_point) + 1,
}

CaffeSlice.update_node_stat(node, update_attrs)
return cls.enabled
4 changes: 2 additions & 2 deletions model-optimizer/extensions/front/caffe/slice_to_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


class SliceToVariadicSplit(FrontReplacementOp):
op = "Slice"
op = "CaffeSlice"
enabled = True

def replace_sub_graph(self, graph: Graph, match: dict):
Expand All @@ -37,7 +37,7 @@ def replace_sub_graph(self, graph: Graph, match: dict):
return

assert node.has_valid('slice_point'), 'Slice operation `{}` has no `slice_point` parameter'.format(name)
slice_point = np.array(node.slice_point)
slice_point = node.slice_point

if slice_point.size == 0:
num_splits = len(node.out_ports())
Expand Down
2 changes: 1 addition & 1 deletion model-optimizer/extensions/front/mxnet/slice_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

import numpy as np

from extensions.ops.mxslice import MXSlice
from mo.front.extractor import FrontExtractorOp
from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
from mo.ops.slice import MXSlice


class SliceFrontExtractor(FrontExtractorOp):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from mo.ops.strided_slice import StridedSlice


class SliceFrontReplacer(FrontReplacementOp):
class MXSliceToStridedSliceReplacer(FrontReplacementOp):
op = 'MXSlice'
enabled = True

Expand Down
38 changes: 38 additions & 0 deletions model-optimizer/extensions/front/onnx/AttributedSliceToSlice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes
from mo.ops.slice import Slice


class AttributedSliceToSliceReplacer(FrontReplacementOp):
"""
This class replaces AttributedSlice -> Slice
"""
op = 'AttributedSlice'
pavel-esir marked this conversation as resolved.
Show resolved Hide resolved
enabled = True

def replace_sub_graph(self, graph: Graph, match: dict):
node = match['op']
slice_name = node.soft_get('name', node.id)

slice_node = create_op_with_const_inputs(graph, Slice, {1: node.starts, 2: node.ends, 3: node.axes})
rename_nodes([(node, slice_name + '/to_be_removed'), (slice_node, slice_name)])

node.in_port(0).get_connection().set_destination(slice_node.in_port(0))
node.out_port(0).get_connection().set_source(slice_node.out_port(0))
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

import numpy as np
from generator import generator, generate

from extensions.front.onnx.AttributedSliceToSlice import AttributedSliceToSliceReplacer
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, const, connect_front


@generator
class SliceReplacerTest(unittest.TestCase):
@generate(*[
{'op': 'AttributedSlice', 'type': None, 'starts': np.array([0, 0]), 'ends': np.array([1, -1]), 'axes': np.array([0, 1])}
])
def test_attributed_slice_replacer(self, attributed_slice_attrs):
nodes = {
**regular_op_with_empty_data('input', {'type': 'Parameter'}),
**regular_op_with_empty_data('attributed_slice', attributed_slice_attrs),
**result(),

# nodes after replacement
**const('start', np.array([0, 0])),
**const('end', np.array([1, -1])),
**const('axis', np.array(np.array([0, 1]))),
**regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}),
}

graph = build_graph(nodes_attrs=nodes, edges=[
('input', 'attributed_slice'),
('attributed_slice', 'output'),
], nodes_with_edges_only=True)
graph.stage = 'front'

AttributedSliceToSliceReplacer().find_and_replace_pattern(graph)

graph_ref = build_graph(nodes_attrs=nodes, edges=[
('input', 'slice'),
*connect_front('start', '1:slice'),
*connect_front('end', '2:slice'),
*connect_front('axis', '3:slice'),
('slice', 'output'),
], nodes_with_edges_only=True)

(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
32 changes: 18 additions & 14 deletions model-optimizer/extensions/front/onnx/slice_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@

import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import get_onnx_opset_version
from mo.front.onnx.extractors.utils import onnx_attr
from mo.ops.slice import Slice
from mo.ops.slice import Slice, AttributedSlice
from mo.utils.error import Error


class SliceFrontExtractor(FrontExtractorOp):
Expand All @@ -27,17 +30,18 @@ class SliceFrontExtractor(FrontExtractorOp):

@classmethod
def extract(cls, node):
axis = np.array(onnx_attr(node, 'axes', 'ints', default=[]), dtype=np.int64)
start = np.array(onnx_attr(node, 'starts', 'ints', default=[]), dtype=np.int64)
end = np.array(onnx_attr(node, 'ends', 'ints', default=[]), dtype=np.int64)

attrs = {
'axis': axis if len(axis) != 0 else None,
'start': start if len(start) != 0 else None,
'end': end if len(end) != 0 else None,
'format': 'onnx'
}

# update the attributes of the node
Slice.update_node_stat(node, attrs)
if get_onnx_opset_version(node) < 10:
starts = int64_array(onnx_attr(node, 'starts', 'ints', default=[]))
ends = int64_array(onnx_attr(node, 'ends', 'ints', default=[]))
axes = int64_array(onnx_attr(node, 'axes', 'ints', default=[]))

if len(starts) == 0 or len(ends) == 0:
raise Error("starts or/and ends are not specified for the node {}".format(node.name))
if len(axes) == 0:
axes = np.arange(len(starts), dtype=np.int)

attrs = {'axes': axes, 'starts': starts, 'ends': ends}
AttributedSlice.update_node_stat(node, attrs)
else: # onnx_opset_version >= 10
Slice.update_node_stat(node)
return cls.enabled
75 changes: 0 additions & 75 deletions model-optimizer/extensions/front/onnx/slice_ext_test.py

This file was deleted.

Loading