Skip to content

Commit

Permalink
fixed get_shape_from_slice and moved to common utils
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Feb 2, 2021
1 parent a19dda7 commit f6b646a
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 32 deletions.
29 changes: 27 additions & 2 deletions model-optimizer/mo/front/common/partial_infer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""

import logging as log
from typing import Iterable
from typing import Iterable, List, Union

import numpy as np

Expand Down Expand Up @@ -109,4 +109,29 @@ def broadcast_shape(first_shape, second_shape):
assert a_val == 1 or b_val == 1 or a_val == b_val, "Input shape do not broadcast"
new_val = b_val if a_val == 1 else a_val
new_shape[-i - 1] = new_val
return int64_array(new_shape)
return int64_array(new_shape)


def get_shape_from_slice(input_shape: np.ndarray, slices: List[Union[int, slice]]) -> np.ndarray:
"""
Calculate shape of a tensor after slicing without actually creating the resulting tensor.
Is introduced to prevent potentially large memory consumption.
"""
output_shape = []
in_idx, out_idx = 0, 0
for i, s in enumerate(slices):
if isinstance(s, slice):
output_shape.append(len(range(*s.indices(input_shape[in_idx]))))
out_idx += 1
in_idx += 1
elif s is None: # new_axis
output_shape.append(1)
out_idx += 1
elif isinstance(s, int): # shrink_axis
in_idx += 1
else:
raise Exception('Element type of a slice List is unacceptable. '
'Allowed types are: slice, int, and None. Instead got: '. format(type(s)))
for i in range(in_idx, len(input_shape)):
output_shape.append(input_shape[i])
return int64_array(output_shape)
29 changes: 1 addition & 28 deletions model-optimizer/mo/ops/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import List, Union

import numpy as np

from mo.front.common.partial_infer.utils import get_shape_from_slice
from mo.graph.graph import Node, Graph
from mo.ops.op import Op
from mo.utils.error import Error
Expand Down Expand Up @@ -163,30 +163,3 @@ def infer(node: Node):
node.out_port(0).data.set_shape(output_shape)
else:
node.out_port(0).data.set_value(input_value[tuple(slice_idx)])


def get_shape_from_slice(input_shape: np.ndarray, slices: List[Union[int, slice]]) -> np.ndarray:
"""
Calculate shape of a tensor after slicing without actually creating the resulting tensor.
Is introduced to prevent potentially large memory consumption.
"""
out_rank = len(slices) - sum(map(lambda x: isinstance(x, int), slices))
output_shape = np.zeros(out_rank, dtype=np.int64)
output_shape = []
in_idx, out_idx = 0, 0
for i, s in enumerate(slices):
if isinstance(s, slice):
output_shape.append(len(range(*s.indices(input_shape[in_idx]))))
out_idx += 1
in_idx += 1
elif s is None: # new_axis
output_shape.append(1)
out_idx += 1
elif isinstance(s, int): # shrink_axis
in_idx += 1
else:
raise Exception('Element type of a slice List is unacceptable. '
'Allowed types are: slice, int, and None. Instead got: '. format(type(s)))
for i in range(in_idx, len(input_shape)):
output_shape.append(input_shape[i])
return output_shape
3 changes: 1 addition & 2 deletions model-optimizer/mo/ops/strided_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@

import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.partial_infer.utils import get_shape_from_slice
from mo.graph.graph import Node, Graph
from mo.ops.op import Op
from mo.utils.error import Error
from mo.utils.utils import array_to_str
from mo.ops.slice import get_shape_from_slice


class StridedSlice(Op):
Expand Down

0 comments on commit f6b646a

Please sign in to comment.