Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reshape-able SliceConverter #3198

Merged
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
315a309
initial commit
yekruglov Oct 30, 2020
cab3992
add cast
yekruglov Oct 30, 2020
c5fd06a
data type fix
yekruglov Nov 2, 2020
d288f98
added tests
yekruglov Nov 3, 2020
d1107b8
added test without axes and steps
yekruglov Nov 3, 2020
1233d05
Merge remote-tracking branch 'upstream/master' into ykruglov/reshape/…
yekruglov Nov 3, 2020
469ad55
remove redundant imports
yekruglov Nov 3, 2020
ddc4df1
discussions resolving
yekruglov Nov 5, 2020
d4ec7de
Merge remote-tracking branch 'upstream/master' into ykruglov/reshape/…
yekruglov Nov 5, 2020
f68f06b
Add cast to TFSliceToSlice
yekruglov Nov 9, 2020
b30f952
layer tests fix
yekruglov Nov 9, 2020
51643f3
Merge remote-tracking branch 'upstream/master' into ykruglov/reshape/…
yekruglov Nov 9, 2020
2de45f9
update unittest
yekruglov Nov 9, 2020
4c52c6b
Merge remote-tracking branch 'upstream/master' into ykruglov/reshape/…
yekruglov Nov 10, 2020
6045113
rework transformation
yekruglov Nov 16, 2020
ac82b30
Merge remote-tracking branch 'upstream/master' into ykruglov/reshape/…
yekruglov Nov 16, 2020
30ae90f
added clamp
yekruglov Nov 18, 2020
9871095
move broadcast
yekruglov Nov 18, 2020
30bb779
update unittests
yekruglov Nov 18, 2020
1db4adc
Merge remote-tracking branch 'upstream/master' into ykruglov/reshape/…
yekruglov Nov 19, 2020
80a0d13
failed e2e fix
yekruglov Nov 19, 2020
1e24765
added comment
yekruglov Nov 20, 2020
0226b18
Merge remote-tracking branch 'upstream/master' into ykruglov/reshape/…
yekruglov Nov 20, 2020
12a9369
Merge remote-tracking branch 'upstream/master' into ykruglov/reshape/…
yekruglov Nov 23, 2020
f49e245
Merge remote-tracking branch 'upstream/master' into ykruglov/reshape/…
yekruglov Nov 24, 2020
a27dc72
little fixes
yekruglov Nov 25, 2020
f61bd38
Merge remote-tracking branch 'upstream/master' into ykruglov/reshape/…
yekruglov Nov 25, 2020
db4a052
Merge remote-tracking branch 'upstream/master' into ykruglov/reshape/…
yekruglov Nov 27, 2020
cbc61fa
Merge remote-tracking branch 'upstream/master' into ykruglov/reshape/…
yekruglov Nov 30, 2020
d1ee8e3
comments update
yekruglov Nov 30, 2020
4efb186
Merge remote-tracking branch 'upstream/master' into ykruglov/reshape/…
yekruglov Dec 1, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions model-optimizer/extensions/front/tf/TFSliceToSlice.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np

from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Add, Equal
from extensions.ops.select import Select
from mo.front.common.replacement import FrontReplacementOp
Expand Down Expand Up @@ -74,4 +75,7 @@ def replace_sub_graph(self, graph: Graph, match: dict):
# out of select to end (2nd of slice)
select_node.out_port(0).connect(slice_node.in_port(2))

cast = Cast(graph, dict(name=sum_node.name + '/CastToI64', dst_type=np.int64)).create_node()
select_node.in_port(2).get_connection().insert_node(cast)

node.out_port(0).get_connection().set_source(slice_node.out_port(0))
7 changes: 5 additions & 2 deletions model-optimizer/extensions/front/tf/TFSliceToSlice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
**regular_op_with_empty_data('equal', {'op': 'Equal', 'type': 'Equal'}),
**regular_op_with_empty_data('select', {'op': 'Select', 'type': 'Select'}),
**regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}),
**regular_op_with_empty_data('cast', {'op': 'Cast', 'type': 'Convert'}),
}


Expand Down Expand Up @@ -68,7 +69,8 @@ def test_slice_replacer_begin_with_2_inputs(self):

*connect_front('equal:0', 'select:0'),

*connect_front('end_const:0', 'select:2'),
*connect_front('end_const:0', 'cast:0'),
*connect_front('cast:0', 'select:2'),
*connect_front('select:0', 'slice:2'),

*connect_front('slice:0', 'output'),
Expand Down Expand Up @@ -97,7 +99,8 @@ def test_slice_replacer(self):
*connect_front('int32_max:0', '1:select'),
*connect_front('minus_one:0', '1:equal'),
*connect_front('equal:0', '0:select'),
*connect_front('end_const:0', '2:select'),
*connect_front('end_const:0', '0:cast'),
*connect_front('cast:0', '2:select'),
*connect_front('select:0', '2:slice'),
*connect_front('slice:0', 'output'),
], nodes_with_edges_only=True)
Expand Down
156 changes: 81 additions & 75 deletions model-optimizer/extensions/middle/SliceConverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,45 @@

import numpy as np

from extensions.ops.Cast import Cast
from extensions.ops.gather import Gather
from mo.front.caffe.extractors.utils import get_canonical_axis_index
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes
from mo.graph.port import Port
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.clamp import Clamp
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.strided_slice import StridedSlice
from mo.utils.error import Error


def convert_negative_indices(indices: np.array, shape: np.array):
for ind, value in enumerate(indices):
if value < 0:
indices[ind] += shape[ind]
def create_ss_interval_border(graph: Graph, slice_border_port: Port, shape, axes, node_name):
yekruglov marked this conversation as resolved.
Show resolved Hide resolved
# the value for 'ends' might be maximum/minimum possible value of int64. This
# value must be converted to maximum/minimum of int32 because such big values do not fit into the int32 which is
# supported by the StridedSlice layer
yekruglov marked this conversation as resolved.
Show resolved Hide resolved
clamp = create_op_with_const_inputs(graph, Clamp, port_value_dict={1: np.iinfo(np.int32).min,
yekruglov marked this conversation as resolved.
Show resolved Hide resolved
2: np.iinfo(np.int32).max},
op_attrs=dict(name=node_name + '/Clamp'))
yekruglov marked this conversation as resolved.
Show resolved Hide resolved
clamp.in_port(0).connect(slice_border_port)
cast = Cast(graph, dict(name=node_name + '/CastToI64', dst_type=np.int64)).create_node()
yekruglov marked this conversation as resolved.
Show resolved Hide resolved
cast.in_port(0).connect(clamp.out_port(0))
concat = Concat(graph, dict(name=node_name + '/Concat', axis=0)).create_node()
for value_idx, port_idx in enumerate(axes):
concat.add_input_port(port_idx)
value = create_op_with_const_inputs(graph, Gather, port_value_dict={1: int64_array([value_idx]),
2: 0},
yekruglov marked this conversation as resolved.
Show resolved Hide resolved
op_attrs={'name': node_name + '/Gather'})
cast.out_port(0).connect(value.in_port(0))
value.out_port(0).connect(concat.in_port(port_idx))
for port_idx in range(len(shape)):
if not concat.is_in_port_connected(port_idx):
concat.add_input_port(port_idx)
const = Const(graph, dict(name=node_name + '/Const', value=int64_array([0]))).create_node()
yekruglov marked this conversation as resolved.
Show resolved Hide resolved
const.out_port(0).connect(concat.in_port(port_idx))

return concat


class ConvertSlice(MiddleReplacementPattern):
Expand All @@ -36,80 +63,59 @@ class ConvertSlice(MiddleReplacementPattern):
"""
yekruglov marked this conversation as resolved.
Show resolved Hide resolved

enabled = True
op = "Slice"
force_clean_up = True
op = "Slice"
yekruglov marked this conversation as resolved.
Show resolved Hide resolved

def run_after(self):
from extensions.middle.pass_separator import MiddleStart
return [MiddleStart]
yekruglov marked this conversation as resolved.
Show resolved Hide resolved

def pattern(self):
return dict(
nodes=[
('slice', dict(kind='op', op='Slice'))
],
edges=[]
)

def replace_pattern(self, graph: Graph, match: dict):
node = match['slice']

input_shape = node.in_port(0).data.get_shape()
output_shape = node.out_port(0).data.get_shape()
starts = node.in_port(1).data.get_value()
ends = node.in_port(2).data.get_value()
if starts is None or ends is None:
raise Error('The input with starts or end is not constant for node {}'.format(node.id))

# the value for 'ends' is usually maximum possible value of int64. This
# value must be converted to maximum of int32 because such big values do not fit into the int32 which is
# supported by the StridedSlice layer
ends = np.clip(ends, np.iinfo(np.int32).min, np.iinfo(np.int32).max)
if node.is_in_port_connected(3):
axes = node.in_port(3).data.get_value()
if axes is None:
raise Error('The input with axes is not constant for node {}'.format(node.id))
else:
axes = int64_array(list(range(starts.size)))

if node.is_in_port_connected(4):
steps = node.in_port(4).data.get_value()
if steps is None:
raise Error('The input with steps is not constant for node {}'.format(node.id))
else:
steps = np.ones([starts.size])

ss_begin_mask = np.zeros(len(input_shape), dtype=np.int32)
ss_end_mask = np.zeros(len(input_shape), dtype=np.int32)
ss_begin = np.zeros(len(input_shape), dtype=np.int32)
ss_end = np.zeros(len(input_shape), dtype=np.int32)
ss_step = np.ones(len(input_shape), dtype=np.int32)

# prepare inputs and attributes for the StridedSlice layer
for i, axis in enumerate(axes):
if starts[i] != 0:
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(op='Slice'):
node_name = node.soft_get('name', node.id)

input_shape = node.in_port(0).data.get_shape()
if node.is_in_port_connected(3):
axes = node.in_port(3).data.get_value().copy()
assert axes is not None, 'The input with axes is not constant for node {}'.format(node_name)
for i, val in enumerate(axes):
axes[i] = get_canonical_axis_index(input_shape, val)
else:
axes = int64_array(range(len(input_shape)))

ss_begin = create_ss_interval_border(graph, node.in_port(1).get_source(), input_shape, axes, node_name)
ss_end = create_ss_interval_border(graph, node.in_port(2).get_source(), input_shape, axes, node_name)
node.in_port(1).disconnect()
node.in_port(2).disconnect()
rename_nodes([(ss_begin, node_name + '/Begin'), (ss_end, node_name + '/End')])

if node.is_in_port_connected(4):
steps = node.in_port(4).data.get_value()
assert steps is not None, 'The input with steps is not constant for node {}'.format(node_name)
else:
steps = np.ones([axes.size])

ss_begin_mask = np.zeros(len(input_shape), dtype=np.int64)
ss_end_mask = np.zeros(len(input_shape), dtype=np.int64)
ss_step = np.ones(len(input_shape), dtype=np.int64)

for i, axis in enumerate(axes):
ss_begin_mask[axis] = 1
ss_begin[axis] = starts[i]

ss_end_mask[axis] = 1
ss_end[axis] = ends[i]

ss_step[axis] = steps[i]

slice_node_name = node.soft_get('name', node.id)

begin_node = Const(graph, {'value': ss_begin, 'name': slice_node_name + '/begin'}).create_node()
end_node = Const(graph, {'value': ss_end, 'name': slice_node_name + '/end'}).create_node()
strides_node = Const(graph, {'value': ss_step, 'name': slice_node_name + '/stride'}).create_node()

ss = StridedSlice(graph, dict(new_axis_mask=np.zeros(len(output_shape), dtype=np.int32),
shrink_axis_mask=np.zeros(len(output_shape), dtype=np.int32),
ellipsis_mask=np.zeros(len(output_shape), dtype=np.int32),
begin_mask=ss_begin_mask,
end_mask=ss_end_mask)).create_node()
rename_nodes([(node, slice_node_name + '_delete'), (ss, slice_node_name)])
node.in_port(0).get_connection().set_destination(ss.in_port(0))
begin_node.out_port(0).connect(ss.in_port(1))
end_node.out_port(0).connect(ss.in_port(2))
strides_node.out_port(0).connect(ss.in_port(3))
node.out_port(0).get_connection().set_source(ss.out_port(0))
ss_end_mask[axis] = 1
ss_step[axis] = steps[i]

ss_strides = Const(graph, dict(name=node_name + '/Strides', value=ss_step)).create_node()

ss = StridedSlice(graph, dict(name='ss', new_axis_mask=np.zeros(len(input_shape), dtype=np.int64),
shrink_axis_mask=np.zeros(len(input_shape), dtype=np.int64),
ellipsis_mask=np.zeros(len(input_shape), dtype=np.int64),
begin_mask=ss_begin_mask,
end_mask=ss_end_mask)).create_node()

node.in_port(0).get_connection().set_destination(ss.in_port(0))
ss.in_port(1).connect(ss_begin.out_port(0))
ss.in_port(2).connect(ss_end.out_port(0))
ss.in_port(3).connect(ss_strides.out_port(0))
node.out_port(0).get_connection().set_source(ss.out_port(0))

rename_nodes([(node, node_name + '/ShouldBeDeleted'), (ss, node_name)])
Loading