Skip to content

Commit

Permalink
Reshape-able SliceConverter (openvinotoolkit#3198)
Browse files Browse the repository at this point in the history
* initial commit

* add cast

* data type fix

* added tests

* added test without axes and steps

* remove redundant imports

* discussions resolving

* Add cast to TFSliceToSlice

* layer tests fix

* update unittest

* rework transformation

* added clamp

* move broadcast

* update unittests

* failed e2e fix

* added comment

* little fixes

* comments update
  • Loading branch information
yekruglov authored and jiwaszki committed Jan 15, 2021
1 parent 7a3227f commit 6c211ec
Show file tree
Hide file tree
Showing 4 changed files with 466 additions and 369 deletions.
4 changes: 4 additions & 0 deletions model-optimizer/extensions/front/tf/TFSliceToSlice.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np

from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Add, Equal
from extensions.ops.select import Select
from mo.front.common.replacement import FrontReplacementOp
Expand Down Expand Up @@ -74,4 +75,7 @@ def replace_sub_graph(self, graph: Graph, match: dict):
# out of select to end (2nd of slice)
select_node.out_port(0).connect(slice_node.in_port(2))

cast = Cast(graph, dict(name=sum_node.name + '/CastToI64', dst_type=np.int64)).create_node()
select_node.in_port(2).get_connection().insert_node(cast)

node.out_port(0).get_connection().set_source(slice_node.out_port(0))
7 changes: 5 additions & 2 deletions model-optimizer/extensions/front/tf/TFSliceToSlice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
**regular_op_with_empty_data('equal', {'op': 'Equal', 'type': 'Equal'}),
**regular_op_with_empty_data('select', {'op': 'Select', 'type': 'Select'}),
**regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}),
**regular_op_with_empty_data('cast', {'op': 'Cast', 'type': 'Convert'}),
}


Expand Down Expand Up @@ -68,7 +69,8 @@ def test_slice_replacer_begin_with_2_inputs(self):

*connect_front('equal:0', 'select:0'),

*connect_front('end_const:0', 'select:2'),
*connect_front('end_const:0', 'cast:0'),
*connect_front('cast:0', 'select:2'),
*connect_front('select:0', 'slice:2'),

*connect_front('slice:0', 'output'),
Expand Down Expand Up @@ -97,7 +99,8 @@ def test_slice_replacer(self):
*connect_front('int32_max:0', '1:select'),
*connect_front('minus_one:0', '1:equal'),
*connect_front('equal:0', '0:select'),
*connect_front('end_const:0', '2:select'),
*connect_front('end_const:0', '0:cast'),
*connect_front('cast:0', '2:select'),
*connect_front('select:0', '2:slice'),
*connect_front('slice:0', 'output'),
], nodes_with_edges_only=True)
Expand Down
177 changes: 97 additions & 80 deletions model-optimizer/extensions/middle/SliceConverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,100 +16,117 @@

import numpy as np

from extensions.ops.Cast import Cast
from extensions.ops.gather import Gather
from mo.front.caffe.extractors.utils import get_canonical_axis_index
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes
from mo.graph.port import Port
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.clamp import Clamp
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.strided_slice import StridedSlice
from mo.utils.error import Error


def convert_negative_indices(indices: np.array, shape: np.array):
for ind, value in enumerate(indices):
if value < 0:
indices[ind] += shape[ind]
def create_ss_interval_border(graph: Graph, slice_border_port: Port, shape: np.ndarray, axes: np.ndarray, node_name: str):
"""
This function creates "begin"/"end" parameters for the StridedSlice based on Slice's "starts"/"ends"
:param graph: graph to operate on.
:param slice_border_port: node output port that provides "starts"/"ends" values for the Slice.
:param shape: input shape of the Slice
:param axes: axes that "starts" and "ends" apply to
:param node_name: Slice node name
:return: Concat node that forms "begin"/"end" values for the StridedSlice
"""
# the value for 'starts' or 'ends' might be maximum/minimum possible value of int64. This
# value must be converted to maximum/minimum of int32 because such big values do not fit into the int32 which is
# supported by the StridedSlice layer
clamp = create_op_with_const_inputs(
graph, Clamp, port_value_dict={1: np.iinfo(np.int32).min, 2: np.iinfo(np.int32).max},
op_attrs=dict(name=node_name + '/Clamp'))
clamp.in_port(0).connect(slice_border_port)
# we have to convert "starts"/"ends" values from the network to one data type with constant values that are created
# here to prevent type errors in Concat node
cast = Cast(graph, dict(name=node_name + '/CastToI64', dst_type=np.int64)).create_node()
cast.in_port(0).connect(clamp.out_port(0))
concat = Concat(graph, dict(name=node_name + '/Concat', axis=0)).create_node()
for value_idx, port_idx in enumerate(axes):
concat.add_input_port(port_idx)
# "axes" may not be sorted, so we need to split "starts"/"ends" values and connect each value to the correct
# Concat input port
value = create_op_with_const_inputs(
graph, Gather, port_value_dict={1: int64_array([value_idx]), 2: int64_array(0)},
op_attrs={'name': node_name + '/Gather'})
cast.out_port(0).connect(value.in_port(0))
value.out_port(0).connect(concat.in_port(port_idx))
for port_idx in range(len(shape)):
if not concat.is_in_port_connected(port_idx):
concat.add_input_port(port_idx)
# This border value would be ignored in StridedSlice because of the begin_mask\end_mask
const = Const(graph, dict(name=node_name + '/Const', value=int64_array([0]))).create_node()
const.out_port(0).connect(concat.in_port(port_idx))

return concat


class ConvertSlice(MiddleReplacementPattern):
"""
This class converts Slice operation to StridedSlice
This class converts a Slice operation to StridedSlice in reshape-able way by parsing the 'starts' and 'ends'
parameters based on the 'axes' parameter
"""

enabled = True
op = "Slice"
force_clean_up = True

def run_after(self):
from extensions.middle.pass_separator import MiddleStart
return [MiddleStart]

def pattern(self):
return dict(
nodes=[
('slice', dict(kind='op', op='Slice'))
],
edges=[]
)

def replace_pattern(self, graph: Graph, match: dict):
node = match['slice']

input_shape = node.in_port(0).data.get_shape()
output_shape = node.out_port(0).data.get_shape()
starts = node.in_port(1).data.get_value()
ends = node.in_port(2).data.get_value()
if starts is None or ends is None:
raise Error('The input with starts or end is not constant for node {}'.format(node.id))

# the value for 'ends' is usually maximum possible value of int64. This
# value must be converted to maximum of int32 because such big values do not fit into the int32 which is
# supported by the StridedSlice layer
ends = np.clip(ends, np.iinfo(np.int32).min, np.iinfo(np.int32).max)
if node.is_in_port_connected(3):
axes = node.in_port(3).data.get_value()
if axes is None:
raise Error('The input with axes is not constant for node {}'.format(node.id))
else:
axes = int64_array(list(range(starts.size)))

if node.is_in_port_connected(4):
steps = node.in_port(4).data.get_value()
if steps is None:
raise Error('The input with steps is not constant for node {}'.format(node.id))
else:
steps = np.ones([starts.size])

ss_begin_mask = np.zeros(len(input_shape), dtype=np.int32)
ss_end_mask = np.zeros(len(input_shape), dtype=np.int32)
ss_begin = np.zeros(len(input_shape), dtype=np.int32)
ss_end = np.zeros(len(input_shape), dtype=np.int32)
ss_step = np.ones(len(input_shape), dtype=np.int32)

# prepare inputs and attributes for the StridedSlice layer
for i, axis in enumerate(axes):
if starts[i] != 0:
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(op='Slice'):
node_name = node.soft_get('name', node.id)

input_shape = node.in_port(0).data.get_shape()
if node.is_in_port_connected(3):
axes = node.in_port(3).data.get_value().copy()
assert axes is not None, 'The input with axes is not constant for node {}'.format(node_name)
for i, val in enumerate(axes):
axes[i] = get_canonical_axis_index(input_shape, val)
else:
axes = int64_array(range(len(input_shape)))

ss_begin = create_ss_interval_border(graph, node.in_port(1).get_source(), input_shape, axes, node_name)
ss_end = create_ss_interval_border(graph, node.in_port(2).get_source(), input_shape, axes, node_name)
node.in_port(1).disconnect()
node.in_port(2).disconnect()
rename_nodes([(ss_begin, node_name + '/Begin'), (ss_end, node_name + '/End')])

if node.is_in_port_connected(4):
steps = node.in_port(4).data.get_value()
assert steps is not None, 'The input with steps is not constant for node {}'.format(node_name)
else:
steps = np.ones([axes.size])

ss_begin_mask = np.zeros(len(input_shape), dtype=np.int64)
ss_end_mask = np.zeros(len(input_shape), dtype=np.int64)
ss_step = np.ones(len(input_shape), dtype=np.int64)

for i, axis in enumerate(axes):
ss_begin_mask[axis] = 1
ss_begin[axis] = starts[i]

ss_end_mask[axis] = 1
ss_end[axis] = ends[i]

ss_step[axis] = steps[i]

slice_node_name = node.soft_get('name', node.id)

begin_node = Const(graph, {'value': ss_begin, 'name': slice_node_name + '/begin'}).create_node()
end_node = Const(graph, {'value': ss_end, 'name': slice_node_name + '/end'}).create_node()
strides_node = Const(graph, {'value': ss_step, 'name': slice_node_name + '/stride'}).create_node()

ss = StridedSlice(graph, dict(new_axis_mask=np.zeros(len(output_shape), dtype=np.int32),
shrink_axis_mask=np.zeros(len(output_shape), dtype=np.int32),
ellipsis_mask=np.zeros(len(output_shape), dtype=np.int32),
begin_mask=ss_begin_mask,
end_mask=ss_end_mask)).create_node()
rename_nodes([(node, slice_node_name + '_delete'), (ss, slice_node_name)])
node.in_port(0).get_connection().set_destination(ss.in_port(0))
begin_node.out_port(0).connect(ss.in_port(1))
end_node.out_port(0).connect(ss.in_port(2))
strides_node.out_port(0).connect(ss.in_port(3))
node.out_port(0).get_connection().set_source(ss.out_port(0))
ss_end_mask[axis] = 1
ss_step[axis] = steps[i]

ss_strides = Const(graph, dict(name=node_name + '/Strides', value=ss_step)).create_node()

ss = StridedSlice(graph, dict(name='ss', new_axis_mask=np.zeros(len(input_shape), dtype=np.int64),
shrink_axis_mask=np.zeros(len(input_shape), dtype=np.int64),
ellipsis_mask=np.zeros(len(input_shape), dtype=np.int64),
begin_mask=ss_begin_mask,
end_mask=ss_end_mask)).create_node()

node.in_port(0).get_connection().set_destination(ss.in_port(0))
ss.in_port(1).connect(ss_begin.out_port(0))
ss.in_port(2).connect(ss_end.out_port(0))
ss.in_port(3).connect(ss_strides.out_port(0))
node.out_port(0).get_connection().set_source(ss.out_port(0))

rename_nodes([(node, node_name + '/ShouldBeDeleted'), (ss, node_name)])
Loading

0 comments on commit 6c211ec

Please sign in to comment.