From f6b646a8f1d4f7bcbf92571c8481042d955f6280 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Tue, 2 Feb 2021 14:21:30 +0300 Subject: [PATCH] fixed get_shape_from_slice and moved to common utils --- .../mo/front/common/partial_infer/utils.py | 29 +++++++++++++++++-- model-optimizer/mo/ops/slice.py | 29 +------------------ model-optimizer/mo/ops/strided_slice.py | 3 +- 3 files changed, 29 insertions(+), 32 deletions(-) diff --git a/model-optimizer/mo/front/common/partial_infer/utils.py b/model-optimizer/mo/front/common/partial_infer/utils.py index cbddae0ec292b1..9dbeef011d3997 100644 --- a/model-optimizer/mo/front/common/partial_infer/utils.py +++ b/model-optimizer/mo/front/common/partial_infer/utils.py @@ -15,7 +15,7 @@ """ import logging as log -from typing import Iterable +from typing import Iterable, List, Union import numpy as np @@ -109,4 +109,29 @@ def broadcast_shape(first_shape, second_shape): assert a_val == 1 or b_val == 1 or a_val == b_val, "Input shape do not broadcast" new_val = b_val if a_val == 1 else a_val new_shape[-i - 1] = new_val - return int64_array(new_shape) \ No newline at end of file + return int64_array(new_shape) + + +def get_shape_from_slice(input_shape: np.ndarray, slices: List[Union[int, slice]]) -> np.ndarray: + """ + Calculate shape of a tensor after slicing without actually creating the resulting tensor. + Is introduced to prevent potentially large memory consumption. + """ + output_shape = [] + in_idx, out_idx = 0, 0 + for i, s in enumerate(slices): + if isinstance(s, slice): + output_shape.append(len(range(*s.indices(input_shape[in_idx])))) + out_idx += 1 + in_idx += 1 + elif s is None: # new_axis + output_shape.append(1) + out_idx += 1 + elif isinstance(s, int): # shrink_axis + in_idx += 1 + else: + raise Exception('Element type of a slice List is unacceptable. ' + 'Allowed types are: slice, int, and None. Instead got: '. format(type(s))) + for i in range(in_idx, len(input_shape)): + output_shape.append(input_shape[i]) + return int64_array(output_shape) \ No newline at end of file diff --git a/model-optimizer/mo/ops/slice.py b/model-optimizer/mo/ops/slice.py index afabc616f312e3..381478c2457969 100644 --- a/model-optimizer/mo/ops/slice.py +++ b/model-optimizer/mo/ops/slice.py @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. """ -from typing import List, Union import numpy as np +from mo.front.common.partial_infer.utils import get_shape_from_slice from mo.graph.graph import Node, Graph from mo.ops.op import Op from mo.utils.error import Error @@ -163,30 +163,3 @@ def infer(node: Node): node.out_port(0).data.set_shape(output_shape) else: node.out_port(0).data.set_value(input_value[tuple(slice_idx)]) - - -def get_shape_from_slice(input_shape: np.ndarray, slices: List[Union[int, slice]]) -> np.ndarray: - """ - Calculate shape of a tensor after slicing without actually creating the resulting tensor. - Is introduced to prevent potentially large memory consumption. - """ - out_rank = len(slices) - sum(map(lambda x: isinstance(x, int), slices)) - output_shape = np.zeros(out_rank, dtype=np.int64) - output_shape = [] - in_idx, out_idx = 0, 0 - for i, s in enumerate(slices): - if isinstance(s, slice): - output_shape.append(len(range(*s.indices(input_shape[in_idx])))) - out_idx += 1 - in_idx += 1 - elif s is None: # new_axis - output_shape.append(1) - out_idx += 1 - elif isinstance(s, int): # shrink_axis - in_idx += 1 - else: - raise Exception('Element type of a slice List is unacceptable. ' - 'Allowed types are: slice, int, and None. Instead got: '. format(type(s))) - for i in range(in_idx, len(input_shape)): - output_shape.append(input_shape[i]) - return output_shape diff --git a/model-optimizer/mo/ops/strided_slice.py b/model-optimizer/mo/ops/strided_slice.py index b7535a1870b3fe..70a0da24cfe9b5 100644 --- a/model-optimizer/mo/ops/strided_slice.py +++ b/model-optimizer/mo/ops/strided_slice.py @@ -16,12 +16,11 @@ import numpy as np -from mo.front.common.partial_infer.utils import int64_array +from mo.front.common.partial_infer.utils import get_shape_from_slice from mo.graph.graph import Node, Graph from mo.ops.op import Op from mo.utils.error import Error from mo.utils.utils import array_to_str -from mo.ops.slice import get_shape_from_slice class StridedSlice(Op):