forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reshape able slice (openvinotoolkit#1241)
* Added Caffe Slice_ext * Added TFSlice, AttributedSlice (both with extractors and replacers), corrected SliceConverter and added unittests for all cases * added comments to each type of Slice operation; optimized shape inference; moved mxlice inside of slice.py; renamed slice_replacers * removed type annotation for get_shape_after_slice routine * replaced zeros_like with zeros * Corrected preserving node names, renamed attributes names, added tests fro slice_replacer onnx phase * Renamed slice_replacers.py * added more unittest cases * added type annotations, moved to more relevant place routines for shape calculation, and some other minor corrections * corrected a typo `normalize_slice_indices` comment * corrected shape calculation for Nonconstant inputs * corrected a few typos * corrected type declarations * corrected shape inference with rounding * refactored unit-tests for front transforms of Slice * added error raising for negative and zero shapes * removed magic_num * corrected AttributedSlice, clarified comments * fixed unit-test for AttributedSliceToSlice * typo in type hints corrected * removed supported_attrs * returned back default None for attrs of Slice
- Loading branch information
1 parent
f7f32a3
commit 4eca16c
Showing
24 changed files
with
705 additions
and
919 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
""" | ||
Copyright (C) 2018-2020 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
from mo.front.common.partial_infer.utils import int64_array | ||
from mo.front.extractor import FrontExtractorOp | ||
from mo.ops.slice import CaffeSlice | ||
|
||
|
||
class SliceFrontExtractor(FrontExtractorOp): | ||
op = 'slice' | ||
enabled = True | ||
|
||
@classmethod | ||
def extract(cls, node): | ||
proto_layer = node.pb | ||
param = proto_layer.slice_param | ||
|
||
# slice_dim is deprecated parameter and is used as alias for axis | ||
# however if slice_dim is defined and axis is default, we use slice_dim | ||
if param.slice_dim != 1 and param.axis == 1: | ||
axis = param.slice_dim | ||
else: | ||
axis = param.axis | ||
|
||
update_attrs = { | ||
'axis': axis, | ||
'slice_point': int64_array(param.slice_point), | ||
'in_ports_count': 1, | ||
'out_ports_count': len(param.slice_point) + 1, | ||
} | ||
|
||
CaffeSlice.update_node_stat(node, update_attrs) | ||
return cls.enabled |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
38 changes: 38 additions & 0 deletions
38
model-optimizer/extensions/front/onnx/AttributedSliceToSlice.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
""" | ||
Copyright (C) 2018-2020 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
from mo.front.common.replacement import FrontReplacementOp | ||
from mo.front.tf.graph_utils import create_op_with_const_inputs | ||
from mo.graph.graph import Graph, rename_nodes | ||
from mo.ops.slice import Slice | ||
|
||
|
||
class AttributedSliceToSliceReplacer(FrontReplacementOp): | ||
""" | ||
This class replaces AttributedSlice -> Slice | ||
""" | ||
op = 'AttributedSlice' | ||
enabled = True | ||
|
||
def replace_sub_graph(self, graph: Graph, match: dict): | ||
node = match['op'] | ||
slice_name = node.soft_get('name', node.id) | ||
|
||
slice_node = create_op_with_const_inputs(graph, Slice, {1: node.starts, 2: node.ends, 3: node.axes}) | ||
rename_nodes([(node, slice_name + '/to_be_removed'), (slice_node, slice_name)]) | ||
|
||
node.in_port(0).get_connection().set_destination(slice_node.in_port(0)) | ||
node.out_port(0).get_connection().set_source(slice_node.out_port(0)) |
62 changes: 62 additions & 0 deletions
62
model-optimizer/extensions/front/onnx/AttributedSliceToSlice_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
""" | ||
Copyright (C) 2018-2020 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import unittest | ||
|
||
import numpy as np | ||
from generator import generator, generate | ||
|
||
from extensions.front.onnx.AttributedSliceToSlice import AttributedSliceToSliceReplacer | ||
from mo.utils.ir_engine.compare_graphs import compare_graphs | ||
from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, const, connect_front | ||
|
||
|
||
@generator | ||
class SliceReplacerTest(unittest.TestCase): | ||
@generate(*[ | ||
{'op': 'AttributedSlice', 'type': None, 'starts': np.array([0, 0]), 'ends': np.array([1, -1]), 'axes': np.array([0, 1])} | ||
]) | ||
def test_attributed_slice_replacer(self, attributed_slice_attrs): | ||
nodes = { | ||
**regular_op_with_empty_data('input', {'type': 'Parameter'}), | ||
**regular_op_with_empty_data('attributed_slice', attributed_slice_attrs), | ||
**result(), | ||
|
||
# nodes after replacement | ||
**const('start', np.array([0, 0])), | ||
**const('end', np.array([1, -1])), | ||
**const('axis', np.array(np.array([0, 1]))), | ||
**regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}), | ||
} | ||
|
||
graph = build_graph(nodes_attrs=nodes, edges=[ | ||
('input', 'attributed_slice'), | ||
('attributed_slice', 'output'), | ||
], nodes_with_edges_only=True) | ||
graph.stage = 'front' | ||
|
||
AttributedSliceToSliceReplacer().find_and_replace_pattern(graph) | ||
|
||
graph_ref = build_graph(nodes_attrs=nodes, edges=[ | ||
('input', 'slice'), | ||
*connect_front('start', '1:slice'), | ||
*connect_front('end', '2:slice'), | ||
*connect_front('axis', '3:slice'), | ||
('slice', 'output'), | ||
], nodes_with_edges_only=True) | ||
|
||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) | ||
self.assertTrue(flag, resp) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.