Skip to content

Commit

Permalink
successfully converted xj_feauture and crash when loading with the ne…
Browse files Browse the repository at this point in the history
…w rewritten SS infer
  • Loading branch information
pavel-esir committed Feb 2, 2021
1 parent 90378af commit a19dda7
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 151 deletions.
3 changes: 1 addition & 2 deletions model-optimizer/extensions/back/CropToStridedSlice.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def replace_pattern(self, graph: Graph, match: [str, Node]):
'shrink_axis_mask': np.zeros(len(end_mask)),
'ellipsis_mask': np.zeros(len(end_mask))}).create_node()


if len(node.in_nodes()) == 2 and node.has_valid('offset'):
# Crop Type 1
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node.offset),
Expand Down Expand Up @@ -116,7 +115,7 @@ def replace_pattern(self, graph: Graph, match: [str, Node]):
source = node.in_port(0).get_connection().get_source()

stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64),
'name': ss.name + '/stride'}).create_node()
'name': ss.name + '/stride'}).create_node()

source.connect(ss.in_port(0))
begin.out_port(0).connect(ss.in_port(1))
Expand Down
3 changes: 2 additions & 1 deletion model-optimizer/extensions/middle/ApplyPermutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def permute_input_data(graph: Graph):

if permutation_data_node.has_and_set('permutation') and \
not is_input_data_in_correct_layout(node, in_port) and \
len(port_to_check.data.get_shape()) >= 4:
(len(port_to_check.data.get_shape()) >= 4 or \
(node['type'] == 'StridedSlice' and len(node.in_port(1).data.get_shape()) >= 4)):
permutation(node, port_info, in_port)
if node.has_and_set('need_shape_inference'):
node.infer(node)
Expand Down
4 changes: 2 additions & 2 deletions model-optimizer/extensions/middle/StridedSliceNormalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def normalize_strided_slice(self, graph: Graph, node: Node):

if np.any(node.ellipsis_mask):
idx = np.nonzero(node.ellipsis_mask)
assert len(idx[0]) == 1
assert len(idx[0]) == 1, 'only one ellipsis_mask nonzero value is allowed'
ellipsis_start = idx[0][0]
num_inserts = input_rank - slice_rank + np.count_nonzero(node.new_axis_mask)
num_inserts = input_rank - slice_rank + np.count_nonzero(node.new_axis_mask[ellipsis_start:])

node.begin_mask[ellipsis_start] = 0
node.end_mask[ellipsis_start] = 0
Expand Down
22 changes: 8 additions & 14 deletions model-optimizer/extensions/middle/StridedSliceNormalizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,9 @@

nodes = {
**valued_const_with_data('input', input),
**regular_op_with_empty_data('strided_slice', {'op': 'StridedSlice',
'begin_mask': [1, 1, 1],
'end_mask': [1, 1, 1],
'ellipsis_mask': [0, 0, 0],
'new_axis_mask': [0, 0, 0],
'shrink_axis_mask': [0, 0, 0],
'infer': StridedSlice.infer}),
**regular_op_with_empty_data('strided_slice_normalized', {'op': 'StridedSlice',
'begin_mask': [0, 1, 1, 0],
'end_mask': [0, 1, 1, 0],
'ellipsis_mask': [0, 0, 0, 0],
'new_axis_mask': [0, 0, 0, 0],
'shrink_axis_mask': [0, 0, 0, 0],
**regular_op_with_empty_data('strided_slice', {'op': 'StridedSlice', 'begin_mask': [1, 1, 1],
'end_mask': [1, 1, 1], 'ellipsis_mask': [0, 0, 0],
'new_axis_mask': [0, 0, 0], 'shrink_axis_mask': [0, 0, 0],
'infer': StridedSlice.infer}),
**valued_const_with_data('begin', int64_array(begin)),
**valued_const_with_data('begin_placeholder', int64_array([0])),
Expand Down Expand Up @@ -94,7 +84,7 @@
*connect('strided_slice', 'res')
)

edges_ref_extend = (
edges_ref_ellipsis_unrolled = (
*connect('input', '0:strided_slice'),

# after extending begin
Expand All @@ -113,6 +103,10 @@
*connect('strided_slice', 'res')
)

# unrolled and extended

# extended with existing concat


class TestStridedSliceNormalizer(unittest.TestCase):

Expand Down
3 changes: 1 addition & 2 deletions model-optimizer/mo/ops/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,7 @@ class PermuteAttrs:
Attr = namedtuple('Attr', ['name', 'port', 'func'])

common_permutation = lambda node, permutation, attr: node[attr][permutation.perm]
slice_permutation = lambda node, permutation, attr: \
node[attr][permutation.perm] if len(node.in_port(0).data.get_shape()) >= 4 else node[attr]
slice_permutation = lambda node, permutation, attr: node[attr][PermuteAttrs.get_nhwc_to_nchw_permutation(len(node[attr])).perm]
common_permutation_inv = lambda node, permutation, attr: permutation.inv[node[attr]]

# List of default permutations
Expand Down
28 changes: 22 additions & 6 deletions model-optimizer/mo/ops/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import List
from typing import List, Union

import numpy as np

Expand Down Expand Up @@ -157,20 +157,36 @@ def infer(node: Node):
# Ranged for output value for specified axis
slice_idx[axes[i]] = slice(starts[i], ends[i], steps[i])
if input_value is None:
output_shape = get_shape_after_slice(input_shape, slice_idx)
output_shape = get_shape_from_slice(input_shape, slice_idx)
if np.any(output_shape <= 0):
raise Error('Output shape: {} of node "{}" contains non-positive values'.format(output_shape, node.name))
node.out_port(0).data.set_shape(output_shape)
else:
node.out_port(0).data.set_value(input_value[tuple(slice_idx)])


def get_shape_after_slice(input_shape: np.ndarray, slice_idx: List[slice]) -> np.ndarray:
def get_shape_from_slice(input_shape: np.ndarray, slices: List[Union[int, slice]]) -> np.ndarray:
"""
Calculate shape of a tensor after slicing without actually creating the resulting tensor.
Is introduced to prevent potentially large memory consumption.
"""
output_shape = np.zeros(len(input_shape), dtype=np.int32)
for i, s in enumerate(slice_idx):
output_shape[i] = len(range(*s.indices(input_shape[i])))
out_rank = len(slices) - sum(map(lambda x: isinstance(x, int), slices))
output_shape = np.zeros(out_rank, dtype=np.int64)
output_shape = []
in_idx, out_idx = 0, 0
for i, s in enumerate(slices):
if isinstance(s, slice):
output_shape.append(len(range(*s.indices(input_shape[in_idx]))))
out_idx += 1
in_idx += 1
elif s is None: # new_axis
output_shape.append(1)
out_idx += 1
elif isinstance(s, int): # shrink_axis
in_idx += 1
else:
raise Exception('Element type of a slice List is unacceptable. '
'Allowed types are: slice, int, and None. Instead got: '. format(type(s)))
for i in range(in_idx, len(input_shape)):
output_shape.append(input_shape[i])
return output_shape
152 changes: 70 additions & 82 deletions model-optimizer/mo/ops/strided_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from mo.ops.op import Op
from mo.utils.error import Error
from mo.utils.utils import array_to_str
from mo.ops.slice import get_shape_from_slice


class StridedSlice(Op):
Expand Down Expand Up @@ -61,98 +62,85 @@ def infer(node: Node):
def tf_strided_slice_infer(node):
if node.in_node(1).value is None or node.in_node(2).value is None:
raise Error('Strided slice layer supports only constant begin and end inputs')
begin_id = node.in_node(1).value.copy()
end_id = node.in_node(2).value.copy()
begin_id = node.in_port(1).data.get_value()
end_id = node.in_port(2).data.get_value()

if len(node.in_nodes()) > 3:
if node.in_node(3).value is None:
if node.in_port(3).data.get_value() is None:
raise Error('Strided slice layer supports only constant stride input')
stride = node.in_node(3).value
strides = node.in_port(3).data.get_value()
else:
stride = []

strides = np.ones_like(begin_id)
shape = node.in_node(0).shape
value = node.in_port(0).data.get_value()
input_rank = len(shape)

if shape is None or any([x < 0 for x in shape]):
return

convert_negative_indices(begin_id, shape)
convert_negative_indices(end_id, shape)

slice_idx = []
dims = np.amax(np.array([len(begin_id), len(end_id), len(stride),
len(node.shrink_axis_mask), len(node.new_axis_mask), len(node.ellipsis_mask),
len(node.begin_mask), len(node.end_mask)]))

# make mask correct length
def extend_mask(in_mask, fin_len, zeros=True):
mask = list(in_mask)
if len(mask) < fin_len:
if zeros:
mask.extend(np.zeros(dims-len(mask), dtype=np.int32))
else:
mask.extend(np.ones(dims-len(mask), dtype=np.int32))
return np.array(mask, dtype=np.int32)

new_axis_mask = extend_mask(node.new_axis_mask, dims)
shrink_axis_mask = extend_mask(node.shrink_axis_mask, dims)
ellipsis_mask = extend_mask(node.ellipsis_mask, dims)
begin_mask = extend_mask(node.begin_mask, dims)
end_mask = extend_mask(node.end_mask, dims)

old_idx = 0
ellips_ext = 0
id_em = 0
for idx in range(dims):
if new_axis_mask[idx]:
slice_idx.append(np.newaxis)
elif ellipsis_mask[idx]:
ellips_ext = len(shape) - (dims - np.count_nonzero(new_axis_mask) - 1)
id_em = idx
for i in range(0, ellips_ext):
slice_idx.append(slice(0, shape[old_idx], 1))
old_idx = old_idx + 1
assert len(begin_id) == len(end_id) == len(strides), 'begin, end, and strides must be of the same length'

extend_mask = lambda mask: np.append(mask, [0] * (len(begin_id) - len(mask)))
new_axis_mask = extend_mask(node.new_axis_mask)
shrink_axis_mask = extend_mask(node.shrink_axis_mask)
begin_mask = extend_mask(node.begin_mask)
end_mask = extend_mask(node.end_mask)

# unroll ellipsis
if np.any(node.ellipsis_mask):
i = np.nonzero(node.ellipsis_mask)
assert len(i[0]) == 1, 'only one nonzero value in ellipsis_mask is allowed'
ellipsis_start = i[0][0]
num_inserts = input_rank - len(begin_id) + np.count_nonzero(node.new_axis_mask[ellipsis_start:])

# since we don't unse begin, end value
begin_mask[ellipsis_start] = 0
end_mask[ellipsis_start] = 0
new_axis_mask = np.insert(new_axis_mask, ellipsis_start + 1, [0] * num_inserts)
shrink_axis_mask = np.insert(shrink_axis_mask, ellipsis_start + 1, [0] * num_inserts)
begin_mask = np.insert(begin_mask, ellipsis_start + 1, [0] * num_inserts)
end_mask = np.insert(end_mask, ellipsis_start + 1, [0] * num_inserts)

begin_id = np.insert(end_id, ellipsis_start + 1, [0] * num_inserts)
end_id = np.insert(end_id, ellipsis_start + 1, [0] * num_inserts)
strides = np.insert(strides, ellipsis_start + 1, [1] * num_inserts)

# from now slices are without ellipsis
dims = len(begin_id)
slice_idx = [[]] * dims
in_idx = 0
for i in range(dims):
if new_axis_mask[i]:
slice_idx[i] = np.newaxis
elif shrink_axis_mask[i]:
begin = begin_id[in_idx]
if begin < 0:
begin += shape[in_idx]
slice_idx[i] = int(begin)
else:
s = stride[idx] if len(stride) > idx else 1
def_beg = 0 if s > 0 else -1
def_end = shape[old_idx] if s > 0 else -shape[old_idx]-1
l = begin_id[idx] if begin_mask[idx] and idx < len(begin_id) else def_beg
r = end_id[idx] if end_mask[idx] and idx < len(end_id) else def_end

# Check shrink_axis_mask
if shrink_axis_mask[idx] and idx < len(shape):
slice_idx.append(slice(l, l+1, s))
else:
slice_idx.append(slice(l, r, s))
old_idx = old_idx + 1

value = node.in_node(0).value if node.in_node(0).value is not None else np.zeros(shape)
value = value[tuple(slice_idx)]

for idx, flag in reversed(list(enumerate(shrink_axis_mask))):
if flag:
if ellips_ext > 0 and idx > id_em:
idx = idx + ellips_ext - 1
try:
value = np.squeeze(value, idx)
except ValueError:
# ignore this error
continue

for i, s in enumerate(slice_idx):
if s is None:
slice_idx[i] = slice(0, 1, 1)
begin = begin_id[in_idx]
end = end_id[in_idx]
if not begin_mask[i]:
begin = 0 if strides[in_idx] > 0 else -1
if not end_mask[i]:
end = shape[in_idx] if strides[in_idx] > 0 else -shape[in_idx] - 1
slice_idx[i] = slice(begin, end, strides[in_idx])
in_idx += 1 if not new_axis_mask[i] else 0

if value is not None:
node.out_port(0).data.set_value(value[tuple(slice_idx)])
else:
node.out_port(0).data.set_shape(get_shape_from_slice(shape, slice_idx))

in_idx = 0
for i in range(dims):
if new_axis_mask[i]:
slice_idx[i] = slice(0, 1, 1)
elif shrink_axis_mask[i]:
slice_idx[i] = slice(slice_idx[i], slice_idx[i] + 1, strides[i])
if not new_axis_mask[i]:
slice_idx[i] = slice(*slice_idx[i].indices(shape[in_idx])) # will convert negative indices
in_idx += 1
node['slices'] = np.array(slice_idx)
for attr in ('shrink_axis_mask', 'new_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask'):
node[attr] = np.array(node[attr], dtype=np.int32)

node['force_precision_in_ports'] = {port: 'int64' for port in range(1, len(node.in_nodes()))}

node.out_node().value = value.copy() if node.in_node(0).value is not None else None
node.out_node().shape = np.array(value.shape, dtype=np.int64)


def convert_negative_indices(indices: np.array, shape: np.array):
for ind, value in enumerate(indices):
if value < 0:
indices[ind] += shape[ind]
Loading

0 comments on commit a19dda7

Please sign in to comment.