Skip to content

Commit

Permalink
successfully moved PermuteAttrs to general mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Feb 2, 2021
1 parent e309f23 commit 90378af
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 50 deletions.
1 change: 1 addition & 0 deletions model-optimizer/extensions/back/CropToStridedSlice.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def replace_pattern(self, graph: Graph, match: [str, Node]):

in_shape = node.in_port(0).data.get_shape()
shape_rank = in_shape.size
# dangerous place
axis_mask = int64_array([1 if i in node_axis else 0 for i in range(shape_rank)])
begin_mask = axis_mask.copy()
end_mask = axis_mask.copy()
Expand Down
13 changes: 10 additions & 3 deletions model-optimizer/extensions/middle/StridedSliceNormalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.graph.perm_inputs import PermuteInputs
from mo.ops.op import Op, PermuteAttrs


class StridedSliceNormalizer(MiddleReplacementPattern):
enabled = True
Expand All @@ -35,10 +37,16 @@ def find_and_replace_pattern(self, graph: Graph):
ss_nodes = graph.get_op_nodes(op='StridedSlice')
for node in ss_nodes:
self.normalize_strided_slice(graph, node)

PermuteAttrs.create_permute_attrs(node, attrs=[('shrink_axis_mask', 'input:0'),
('new_axis_mask', 'input:0'),
('ellipsis_mask', 'input:0'),
('begin_mask', 'input:0'),
('end_mask', 'input:0')])

PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'shape')
PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'shape')
PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:0', 'shape')
pass

def normalize_strided_slice(self, graph: Graph, node: Node):
input_shape = node.in_port(0).data.get_shape()
Expand All @@ -61,9 +69,8 @@ def normalize_strided_slice(self, graph: Graph, node: Node):
node.ellipsis_mask[ellipsis_start] = 0

self.unroll_ellipsis_for_inputs(graph, node, ellipsis_start, num_inserts)

elif slice_rank < input_rank: # process somehow nonzero
num = input_rank - (slice_rank)
num = input_rank - slice_rank
# extend masks
for mask_name in ['begin_mask', 'end_mask', 'new_axis_mask', 'shrink_axis_mask', 'ellipsis_mask']:
node[mask_name] = np.append(node[mask_name], [0] * num)
Expand Down
9 changes: 7 additions & 2 deletions model-optimizer/mo/ops/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ class PermuteAttrs:
Attr = namedtuple('Attr', ['name', 'port', 'func'])

common_permutation = lambda node, permutation, attr: node[attr][permutation.perm]
slice_permutation = lambda node, permutation, attr: \
node[attr][permutation.perm] if len(node.in_port(0).data.get_shape()) >= 4 else node[attr]
common_permutation_inv = lambda node, permutation, attr: permutation.inv[node[attr]]

# List of default permutations
Expand All @@ -355,8 +357,11 @@ class PermuteAttrs:
'kernel_shape': common_permutation,
'output_shape': common_permutation,
'slices': common_permutation,
'shrink_axis_mask': common_permutation,
'new_axis_mask': common_permutation,
'begin_mask': slice_permutation,
'end_mask': slice_permutation,
'shrink_axis_mask': slice_permutation,
'new_axis_mask': slice_permutation,
'ellipsis_mask': slice_permutation,
'axes': common_permutation_inv,
'axis': common_permutation_inv,
'batch_dims': common_permutation_inv,
Expand Down
46 changes: 1 addition & 45 deletions model-optimizer/mo/ops/strided_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@

from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.graph.perm_inputs import PermuteInputs
from mo.ops.op import Op, PermuteAttrs
from mo.ops.op import Op
from mo.utils.error import Error
from mo.utils.utils import array_to_str

Expand Down Expand Up @@ -58,20 +57,6 @@ def infer(node: Node):
# FW assures that begin, end, and strides are of the same length th
tf_strided_slice_infer(node)

out_shape = node.out_port(0).data.get_shape()
assert out_shape is not None, \
'Output shape was not calculated for node {}'.format(node.name)

PermuteAttrs.create_permute_attrs(node, attrs=[('shrink_axis_mask', 'input:0', permute_masks),
('new_axis_mask', 'input:0', permute_masks),
('ellipsis_mask', 'input:0', permute_masks),
('begin_mask', 'input:0', permute_masks),
('end_mask', 'input:0', permute_masks)])

# PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'shape')
# PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'shape')
# PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:0', 'shape')


def tf_strided_slice_infer(node):
if node.in_node(1).value is None or node.in_node(2).value is None:
Expand Down Expand Up @@ -171,32 +156,3 @@ def convert_negative_indices(indices: np.array, shape: np.array):
for ind, value in enumerate(indices):
if value < 0:
indices[ind] += shape[ind]

def permute_array(node: Node, array: np.array):
"""
This function permutes masks according to permutation parameter. Mask have the same or more length than output
"""
attr_mask_extended = list(array)

# If input and output have length of shape 3 and less, no need to permute
if len(node.in_port(0).data.get_shape()) < 4 and len(node.out_port(0).data.get_shape()) < 4:
return attr_mask_extended

perm_len = len(node.out_port(0).data.get_shape()) + np.count_nonzero(node.shrink_axis_mask)
perm_len = len(array)

perm = PermuteAttrs.get_nhwc_to_nchw_permutation(perm_len)
perm_list = list(perm.perm)
# if mask length is more than output, just add tail that will not be permuted to avoid error
for i in range(perm_len, len(attr_mask_extended)):
perm_list.append(i)
return int64_array(attr_mask_extended)[int64_array(perm_list)]


def permute_masks(node: Node, permutation: PermuteAttrs.Permutation, attr: str):
if not node.has_valid(attr):
return None

node[attr] = permute_array(node, node[attr])
return node[attr]

0 comments on commit 90378af

Please sign in to comment.