Skip to content

Commit

Permalink
Reshape able slice (openvinotoolkit#1241)
Browse files Browse the repository at this point in the history
* Added Caffe Slice_ext

* Added TFSlice, AttributedSlice (both with extractors and replacers), corrected SliceConverter and added unittests for all cases

* added comments to each type of Slice operation; optimized shape inference; moved mxlice inside of slice.py; renamed slice_replacers

* removed type annotation for get_shape_after_slice routine

* replaced zeros_like with zeros

* Corrected preserving node names, renamed attributes names, added tests fro slice_replacer onnx phase

* Renamed slice_replacers.py

* added more unittest cases

* added type annotations, moved to more relevant place routines for shape calculation, and some other minor corrections

* corrected a typo `normalize_slice_indices` comment

* corrected shape calculation for Nonconstant inputs

* corrected a few typos

* corrected type declarations

* corrected shape inference with rounding

* refactored unit-tests for front transforms of Slice

* added error raising for negative and zero shapes

* removed magic_num

* corrected AttributedSlice, clarified comments

* fixed unit-test for AttributedSliceToSlice

* typo in type hints corrected

* removed supported_attrs

* returned back default None for attrs of Slice
  • Loading branch information
pavel-esir authored and mryzhov committed Aug 26, 2020
1 parent 4db1bba commit f8615fa
Show file tree
Hide file tree
Showing 24 changed files with 705 additions and 919 deletions.
5 changes: 3 additions & 2 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ extensions/front/caffe/reshape.py
extensions/front/caffe/shufflechannel_ext.py
extensions/front/caffe/sigmoid.py
extensions/front/caffe/simplernms_ext.py
extensions/front/caffe/slice_ext.py
extensions/front/caffe/slice_to_split.py
extensions/front/caffe/softmax_ext.py
extensions/front/caffe/spatial_transformer_ext.py
Expand Down Expand Up @@ -231,6 +232,7 @@ extensions/front/onnx/activation_ext.py
extensions/front/onnx/affine_ext.py
extensions/front/onnx/argmax_ext.py
extensions/front/onnx/aten_ext.py
extensions/front/onnx/AttributedSliceToSlice.py
extensions/front/onnx/cast_ext.py
extensions/front/onnx/clip_ext.py
extensions/front/onnx/const_ext.py
Expand Down Expand Up @@ -447,6 +449,7 @@ extensions/front/tf/SwitchMergeOptimization.py
extensions/front/tf/TensorArrayExtractors.py
extensions/front/tf/TensorArrayGatherV3.py
extensions/front/tf/tensorflow_custom_operations_config_update.py
extensions/front/tf/TFSliceToSlice.py
extensions/front/tf/tile_ext.py
extensions/front/tf/topk_ext.py
extensions/front/tf/transpose_ext.py
Expand Down Expand Up @@ -624,7 +627,6 @@ extensions/ops/merge.py
extensions/ops/mvn.py
extensions/ops/mxrepeat.py
extensions/ops/mxreshape.py
extensions/ops/mxslice.py
extensions/ops/NextIteration.py
extensions/ops/non_max_suppression.py
extensions/ops/non_zero.py
Expand Down Expand Up @@ -723,7 +725,6 @@ mo/front/caffe/extractors/crop.py
mo/front/caffe/extractors/native_caffe.py
mo/front/caffe/extractors/roipooling.py
mo/front/caffe/extractors/scale.py
mo/front/caffe/extractors/slice.py
mo/front/caffe/extractors/tile.py
mo/front/caffe/extractors/utils.py
mo/front/caffe/loader.py
Expand Down
46 changes: 46 additions & 0 deletions model-optimizer/extensions/front/caffe/slice_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import FrontExtractorOp
from mo.ops.slice import CaffeSlice


class SliceFrontExtractor(FrontExtractorOp):
op = 'slice'
enabled = True

@classmethod
def extract(cls, node):
proto_layer = node.pb
param = proto_layer.slice_param

# slice_dim is deprecated parameter and is used as alias for axis
# however if slice_dim is defined and axis is default, we use slice_dim
if param.slice_dim != 1 and param.axis == 1:
axis = param.slice_dim
else:
axis = param.axis

update_attrs = {
'axis': axis,
'slice_point': int64_array(param.slice_point),
'in_ports_count': 1,
'out_ports_count': len(param.slice_point) + 1,
}

CaffeSlice.update_node_stat(node, update_attrs)
return cls.enabled
4 changes: 2 additions & 2 deletions model-optimizer/extensions/front/caffe/slice_to_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


class SliceToVariadicSplit(FrontReplacementOp):
op = "Slice"
op = "CaffeSlice"
enabled = True

def replace_sub_graph(self, graph: Graph, match: dict):
Expand All @@ -37,7 +37,7 @@ def replace_sub_graph(self, graph: Graph, match: dict):
return

assert node.has_valid('slice_point'), 'Slice operation `{}` has no `slice_point` parameter'.format(name)
slice_point = np.array(node.slice_point)
slice_point = node.slice_point

if slice_point.size == 0:
num_splits = len(node.out_ports())
Expand Down
2 changes: 1 addition & 1 deletion model-optimizer/extensions/front/mxnet/slice_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

import numpy as np

from extensions.ops.mxslice import MXSlice
from mo.front.extractor import FrontExtractorOp
from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
from mo.ops.slice import MXSlice


class SliceFrontExtractor(FrontExtractorOp):
Expand Down
2 changes: 1 addition & 1 deletion model-optimizer/extensions/front/mxnet/slice_replacers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from mo.ops.strided_slice import StridedSlice


class SliceFrontReplacer(FrontReplacementOp):
class MXSliceToStridedSliceReplacer(FrontReplacementOp):
op = 'MXSlice'
enabled = True

Expand Down
38 changes: 38 additions & 0 deletions model-optimizer/extensions/front/onnx/AttributedSliceToSlice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes
from mo.ops.slice import Slice


class AttributedSliceToSliceReplacer(FrontReplacementOp):
"""
This class replaces AttributedSlice -> Slice
"""
op = 'AttributedSlice'
enabled = True

def replace_sub_graph(self, graph: Graph, match: dict):
node = match['op']
slice_name = node.soft_get('name', node.id)

slice_node = create_op_with_const_inputs(graph, Slice, {1: node.starts, 2: node.ends, 3: node.axes})
rename_nodes([(node, slice_name + '/to_be_removed'), (slice_node, slice_name)])

node.in_port(0).get_connection().set_destination(slice_node.in_port(0))
node.out_port(0).get_connection().set_source(slice_node.out_port(0))
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

import numpy as np
from generator import generator, generate

from extensions.front.onnx.AttributedSliceToSlice import AttributedSliceToSliceReplacer
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, const, connect_front


@generator
class SliceReplacerTest(unittest.TestCase):
@generate(*[
{'op': 'AttributedSlice', 'type': None, 'starts': np.array([0, 0]), 'ends': np.array([1, -1]), 'axes': np.array([0, 1])}
])
def test_attributed_slice_replacer(self, attributed_slice_attrs):
nodes = {
**regular_op_with_empty_data('input', {'type': 'Parameter'}),
**regular_op_with_empty_data('attributed_slice', attributed_slice_attrs),
**result(),

# nodes after replacement
**const('start', np.array([0, 0])),
**const('end', np.array([1, -1])),
**const('axis', np.array(np.array([0, 1]))),
**regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}),
}

graph = build_graph(nodes_attrs=nodes, edges=[
('input', 'attributed_slice'),
('attributed_slice', 'output'),
], nodes_with_edges_only=True)
graph.stage = 'front'

AttributedSliceToSliceReplacer().find_and_replace_pattern(graph)

graph_ref = build_graph(nodes_attrs=nodes, edges=[
('input', 'slice'),
*connect_front('start', '1:slice'),
*connect_front('end', '2:slice'),
*connect_front('axis', '3:slice'),
('slice', 'output'),
], nodes_with_edges_only=True)

(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
32 changes: 18 additions & 14 deletions model-optimizer/extensions/front/onnx/slice_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@

import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import get_onnx_opset_version
from mo.front.onnx.extractors.utils import onnx_attr
from mo.ops.slice import Slice
from mo.ops.slice import Slice, AttributedSlice
from mo.utils.error import Error


class SliceFrontExtractor(FrontExtractorOp):
Expand All @@ -27,17 +30,18 @@ class SliceFrontExtractor(FrontExtractorOp):

@classmethod
def extract(cls, node):
axis = np.array(onnx_attr(node, 'axes', 'ints', default=[]), dtype=np.int64)
start = np.array(onnx_attr(node, 'starts', 'ints', default=[]), dtype=np.int64)
end = np.array(onnx_attr(node, 'ends', 'ints', default=[]), dtype=np.int64)

attrs = {
'axis': axis if len(axis) != 0 else None,
'start': start if len(start) != 0 else None,
'end': end if len(end) != 0 else None,
'format': 'onnx'
}

# update the attributes of the node
Slice.update_node_stat(node, attrs)
if get_onnx_opset_version(node) < 10:
starts = int64_array(onnx_attr(node, 'starts', 'ints', default=[]))
ends = int64_array(onnx_attr(node, 'ends', 'ints', default=[]))
axes = int64_array(onnx_attr(node, 'axes', 'ints', default=[]))

if len(starts) == 0 or len(ends) == 0:
raise Error("starts or/and ends are not specified for the node {}".format(node.name))
if len(axes) == 0:
axes = np.arange(len(starts), dtype=np.int)

attrs = {'axes': axes, 'starts': starts, 'ends': ends}
AttributedSlice.update_node_stat(node, attrs)
else: # onnx_opset_version >= 10
Slice.update_node_stat(node)
return cls.enabled
75 changes: 0 additions & 75 deletions model-optimizer/extensions/front/onnx/slice_ext_test.py

This file was deleted.

Loading

0 comments on commit f8615fa

Please sign in to comment.