diff --git a/model-optimizer/extensions/front/tf/TFSliceToSlice.py b/model-optimizer/extensions/front/tf/TFSliceToSlice.py index 8c03ca7376a1d6..62e59878452249 100644 --- a/model-optimizer/extensions/front/tf/TFSliceToSlice.py +++ b/model-optimizer/extensions/front/tf/TFSliceToSlice.py @@ -16,6 +16,7 @@ import numpy as np +from extensions.ops.Cast import Cast from extensions.ops.elementwise import Add, Equal from extensions.ops.select import Select from mo.front.common.replacement import FrontReplacementOp @@ -74,4 +75,7 @@ def replace_sub_graph(self, graph: Graph, match: dict): # out of select to end (2nd of slice) select_node.out_port(0).connect(slice_node.in_port(2)) + cast = Cast(graph, dict(name=sum_node.name + '/CastToI64', dst_type=np.int64)).create_node() + select_node.in_port(2).get_connection().insert_node(cast) + node.out_port(0).get_connection().set_source(slice_node.out_port(0)) diff --git a/model-optimizer/extensions/front/tf/TFSliceToSlice_test.py b/model-optimizer/extensions/front/tf/TFSliceToSlice_test.py index 14be81eb3c43cd..2919a71bfc47b0 100644 --- a/model-optimizer/extensions/front/tf/TFSliceToSlice_test.py +++ b/model-optimizer/extensions/front/tf/TFSliceToSlice_test.py @@ -37,6 +37,7 @@ **regular_op_with_empty_data('equal', {'op': 'Equal', 'type': 'Equal'}), **regular_op_with_empty_data('select', {'op': 'Select', 'type': 'Select'}), **regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}), + **regular_op_with_empty_data('cast', {'op': 'Cast', 'type': 'Convert'}), } @@ -68,7 +69,8 @@ def test_slice_replacer_begin_with_2_inputs(self): *connect_front('equal:0', 'select:0'), - *connect_front('end_const:0', 'select:2'), + *connect_front('end_const:0', 'cast:0'), + *connect_front('cast:0', 'select:2'), *connect_front('select:0', 'slice:2'), *connect_front('slice:0', 'output'), @@ -97,7 +99,8 @@ def test_slice_replacer(self): *connect_front('int32_max:0', '1:select'), *connect_front('minus_one:0', '1:equal'), *connect_front('equal:0', '0:select'), - *connect_front('end_const:0', '2:select'), + *connect_front('end_const:0', '0:cast'), + *connect_front('cast:0', '2:select'), *connect_front('select:0', '2:slice'), *connect_front('slice:0', 'output'), ], nodes_with_edges_only=True) diff --git a/model-optimizer/extensions/middle/SliceConverter.py b/model-optimizer/extensions/middle/SliceConverter.py index 5ed91a756ff11b..5a9df90917ea90 100644 --- a/model-optimizer/extensions/middle/SliceConverter.py +++ b/model-optimizer/extensions/middle/SliceConverter.py @@ -16,100 +16,117 @@ import numpy as np +from extensions.ops.Cast import Cast +from extensions.ops.gather import Gather +from mo.front.caffe.extractors.utils import get_canonical_axis_index from mo.front.common.partial_infer.utils import int64_array +from mo.front.tf.graph_utils import create_op_with_const_inputs from mo.graph.graph import Graph, rename_nodes +from mo.graph.port import Port from mo.middle.replacement import MiddleReplacementPattern +from mo.ops.clamp import Clamp +from mo.ops.concat import Concat from mo.ops.const import Const from mo.ops.strided_slice import StridedSlice -from mo.utils.error import Error -def convert_negative_indices(indices: np.array, shape: np.array): - for ind, value in enumerate(indices): - if value < 0: - indices[ind] += shape[ind] +def create_ss_interval_border(graph: Graph, slice_border_port: Port, shape: np.ndarray, axes: np.ndarray, node_name: str): + """ + This function creates "begin"/"end" parameters for the StridedSlice based on Slice's "starts"/"ends" + + :param graph: graph to operate on. + :param slice_border_port: node output port that provides "starts"/"ends" values for the Slice. + :param shape: input shape of the Slice + :param axes: axes that "starts" and "ends" apply to + :param node_name: Slice node name + :return: Concat node that forms "begin"/"end" values for the StridedSlice + """ + # the value for 'starts' or 'ends' might be maximum/minimum possible value of int64. This + # value must be converted to maximum/minimum of int32 because such big values do not fit into the int32 which is + # supported by the StridedSlice layer + clamp = create_op_with_const_inputs( + graph, Clamp, port_value_dict={1: np.iinfo(np.int32).min, 2: np.iinfo(np.int32).max}, + op_attrs=dict(name=node_name + '/Clamp')) + clamp.in_port(0).connect(slice_border_port) + # we have to convert "starts"/"ends" values from the network to one data type with constant values that are created + # here to prevent type errors in Concat node + cast = Cast(graph, dict(name=node_name + '/CastToI64', dst_type=np.int64)).create_node() + cast.in_port(0).connect(clamp.out_port(0)) + concat = Concat(graph, dict(name=node_name + '/Concat', axis=0)).create_node() + for value_idx, port_idx in enumerate(axes): + concat.add_input_port(port_idx) + # "axes" may not be sorted, so we need to split "starts"/"ends" values and connect each value to the correct + # Concat input port + value = create_op_with_const_inputs( + graph, Gather, port_value_dict={1: int64_array([value_idx]), 2: int64_array(0)}, + op_attrs={'name': node_name + '/Gather'}) + cast.out_port(0).connect(value.in_port(0)) + value.out_port(0).connect(concat.in_port(port_idx)) + for port_idx in range(len(shape)): + if not concat.is_in_port_connected(port_idx): + concat.add_input_port(port_idx) + # This border value would be ignored in StridedSlice because of the begin_mask\end_mask + const = Const(graph, dict(name=node_name + '/Const', value=int64_array([0]))).create_node() + const.out_port(0).connect(concat.in_port(port_idx)) + + return concat class ConvertSlice(MiddleReplacementPattern): """ - This class converts Slice operation to StridedSlice + This class converts a Slice operation to StridedSlice in reshape-able way by parsing the 'starts' and 'ends' + parameters based on the 'axes' parameter """ enabled = True - op = "Slice" force_clean_up = True - def run_after(self): - from extensions.middle.pass_separator import MiddleStart - return [MiddleStart] - - def pattern(self): - return dict( - nodes=[ - ('slice', dict(kind='op', op='Slice')) - ], - edges=[] - ) - - def replace_pattern(self, graph: Graph, match: dict): - node = match['slice'] - - input_shape = node.in_port(0).data.get_shape() - output_shape = node.out_port(0).data.get_shape() - starts = node.in_port(1).data.get_value() - ends = node.in_port(2).data.get_value() - if starts is None or ends is None: - raise Error('The input with starts or end is not constant for node {}'.format(node.id)) - - # the value for 'ends' is usually maximum possible value of int64. This - # value must be converted to maximum of int32 because such big values do not fit into the int32 which is - # supported by the StridedSlice layer - ends = np.clip(ends, np.iinfo(np.int32).min, np.iinfo(np.int32).max) - if node.is_in_port_connected(3): - axes = node.in_port(3).data.get_value() - if axes is None: - raise Error('The input with axes is not constant for node {}'.format(node.id)) - else: - axes = int64_array(list(range(starts.size))) - - if node.is_in_port_connected(4): - steps = node.in_port(4).data.get_value() - if steps is None: - raise Error('The input with steps is not constant for node {}'.format(node.id)) - else: - steps = np.ones([starts.size]) - - ss_begin_mask = np.zeros(len(input_shape), dtype=np.int32) - ss_end_mask = np.zeros(len(input_shape), dtype=np.int32) - ss_begin = np.zeros(len(input_shape), dtype=np.int32) - ss_end = np.zeros(len(input_shape), dtype=np.int32) - ss_step = np.ones(len(input_shape), dtype=np.int32) - - # prepare inputs and attributes for the StridedSlice layer - for i, axis in enumerate(axes): - if starts[i] != 0: + def find_and_replace_pattern(self, graph: Graph): + for node in graph.get_op_nodes(op='Slice'): + node_name = node.soft_get('name', node.id) + + input_shape = node.in_port(0).data.get_shape() + if node.is_in_port_connected(3): + axes = node.in_port(3).data.get_value().copy() + assert axes is not None, 'The input with axes is not constant for node {}'.format(node_name) + for i, val in enumerate(axes): + axes[i] = get_canonical_axis_index(input_shape, val) + else: + axes = int64_array(range(len(input_shape))) + + ss_begin = create_ss_interval_border(graph, node.in_port(1).get_source(), input_shape, axes, node_name) + ss_end = create_ss_interval_border(graph, node.in_port(2).get_source(), input_shape, axes, node_name) + node.in_port(1).disconnect() + node.in_port(2).disconnect() + rename_nodes([(ss_begin, node_name + '/Begin'), (ss_end, node_name + '/End')]) + + if node.is_in_port_connected(4): + steps = node.in_port(4).data.get_value() + assert steps is not None, 'The input with steps is not constant for node {}'.format(node_name) + else: + steps = np.ones([axes.size]) + + ss_begin_mask = np.zeros(len(input_shape), dtype=np.int64) + ss_end_mask = np.zeros(len(input_shape), dtype=np.int64) + ss_step = np.ones(len(input_shape), dtype=np.int64) + + for i, axis in enumerate(axes): ss_begin_mask[axis] = 1 - ss_begin[axis] = starts[i] - - ss_end_mask[axis] = 1 - ss_end[axis] = ends[i] - - ss_step[axis] = steps[i] - - slice_node_name = node.soft_get('name', node.id) - - begin_node = Const(graph, {'value': ss_begin, 'name': slice_node_name + '/begin'}).create_node() - end_node = Const(graph, {'value': ss_end, 'name': slice_node_name + '/end'}).create_node() - strides_node = Const(graph, {'value': ss_step, 'name': slice_node_name + '/stride'}).create_node() - - ss = StridedSlice(graph, dict(new_axis_mask=np.zeros(len(output_shape), dtype=np.int32), - shrink_axis_mask=np.zeros(len(output_shape), dtype=np.int32), - ellipsis_mask=np.zeros(len(output_shape), dtype=np.int32), - begin_mask=ss_begin_mask, - end_mask=ss_end_mask)).create_node() - rename_nodes([(node, slice_node_name + '_delete'), (ss, slice_node_name)]) - node.in_port(0).get_connection().set_destination(ss.in_port(0)) - begin_node.out_port(0).connect(ss.in_port(1)) - end_node.out_port(0).connect(ss.in_port(2)) - strides_node.out_port(0).connect(ss.in_port(3)) - node.out_port(0).get_connection().set_source(ss.out_port(0)) + ss_end_mask[axis] = 1 + ss_step[axis] = steps[i] + + ss_strides = Const(graph, dict(name=node_name + '/Strides', value=ss_step)).create_node() + + ss = StridedSlice(graph, dict(name='ss', new_axis_mask=np.zeros(len(input_shape), dtype=np.int64), + shrink_axis_mask=np.zeros(len(input_shape), dtype=np.int64), + ellipsis_mask=np.zeros(len(input_shape), dtype=np.int64), + begin_mask=ss_begin_mask, + end_mask=ss_end_mask)).create_node() + + node.in_port(0).get_connection().set_destination(ss.in_port(0)) + ss.in_port(1).connect(ss_begin.out_port(0)) + ss.in_port(2).connect(ss_end.out_port(0)) + ss.in_port(3).connect(ss_strides.out_port(0)) + node.out_port(0).get_connection().set_source(ss.out_port(0)) + + rename_nodes([(node, node_name + '/ShouldBeDeleted'), (ss, node_name)]) diff --git a/model-optimizer/extensions/middle/SliceConverter_test.py b/model-optimizer/extensions/middle/SliceConverter_test.py index 92b118d0605b20..63380c28292248 100644 --- a/model-optimizer/extensions/middle/SliceConverter_test.py +++ b/model-optimizer/extensions/middle/SliceConverter_test.py @@ -20,304 +20,377 @@ from extensions.middle.SliceConverter import ConvertSlice from mo.front.common.partial_infer.utils import int64_array -from mo.graph.graph import Node -from mo.ops.slice import Slice from mo.utils.ir_engine.compare_graphs import compare_graphs -from mo.utils.unittest.graph import build_graph +from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, \ + regular_op_with_empty_data, result, connect, connect_data nodes_attributes = { - # input data - 'placeholder_1': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'placeholder_2': {'type': 'Const', 'kind': 'op', 'op': 'Const'}, - 'placeholder_3': {'type': 'Const', 'kind': 'op', 'op': 'Const'}, - 'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - 'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - 'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - # Slice layer - 'slice': {'type': 'Slice', 'kind': 'op', 'op': 'Slice', 'name': 'slice_node'}, - 'slice_data': {'value': None, 'shape': None, 'kind': 'data'}, - # Output operation - 'output_op': {'type': 'Const', 'value': None, 'kind': 'op', 'op': 'Const'}, - 'output_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - 'op_output': { 'kind': 'op', 'op': 'Result'}, - # StridedSlice layer - 'strided_slice': {'kind': 'op', 'op': 'StridedSlice', 'slices': None, 'shrink_axis_mask': None} + **regular_op_with_shaped_data('input', [2, 3, 300, 300], {'type': 'Parameter', 'op': 'Parameter'}), + **regular_op_with_empty_data('starts', {'op': 'Const', 'type': 'Const'}), + **regular_op_with_empty_data('ends', {'op': 'Const', 'type': 'Const'}), + **regular_op_with_empty_data('axes', {'op': 'Const', 'type': 'Const'}), + **regular_op_with_empty_data('steps', {'op': 'Const', 'type': 'Const'}), + **regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}), + + **regular_op_with_empty_data('ss_begin_cast', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}), + **regular_op_with_empty_data('ss_begin_clamp', {'op': 'Clamp', 'type': None}), + **regular_op_with_empty_data('ss_begin_clamp_min', {'value': np.iinfo(np.int32).min, 'op': 'Const', 'type': 'Const'}), + **regular_op_with_empty_data('ss_begin_clamp_max', {'value': np.iinfo(np.int32).max, 'op': 'Const', 'type': 'Const'}), + **regular_op_with_empty_data('ss_begin_gather_0', {'op': 'Gather', 'type': 'Gather'}), + **valued_const_with_data('ss_begin_gather_0_idx', int64_array([0])), + **regular_op_with_shaped_data('ss_begin_gather_0_axis', [], {'op': 'Const', 'type': 'Const', 'value': [0]}), + **regular_op_with_empty_data('ss_begin_gather_1', {'op': 'Gather', 'type': 'Gather'}), + **valued_const_with_data('ss_begin_gather_1_idx', int64_array([1])), + **regular_op_with_shaped_data('ss_begin_gather_1_axis', [], {'op': 'Const', 'type': 'Const', 'value': [0]}), + **regular_op_with_empty_data('ss_begin_gather_2', {'op': 'Gather', 'type': 'Gather'}), + **valued_const_with_data('ss_begin_gather_2_idx', int64_array([2])), + **regular_op_with_shaped_data('ss_begin_gather_2_axis', [], {'op': 'Const', 'type': 'Const', 'value': [0]}), + **regular_op_with_empty_data('ss_begin_gather_3', {'op': 'Gather', 'type': 'Gather'}), + **valued_const_with_data('ss_begin_gather_3_idx', int64_array([3])), + **regular_op_with_shaped_data('ss_begin_gather_3_axis', [], {'op': 'Const', 'type': 'Const', 'value': [0]}), + **regular_op_with_empty_data('ss_begin_const_0', {'op': 'Const', 'type': 'Const', 'value': int64_array([0])}), + **regular_op_with_empty_data('ss_begin_const_1', {'op': 'Const', 'type': 'Const', 'value': int64_array([0])}), + **regular_op_with_empty_data('ss_begin_const_2', {'op': 'Const', 'type': 'Const', 'value': int64_array([0])}), + **regular_op_with_empty_data('ss_begin_const_3', {'op': 'Const', 'type': 'Const', 'value': int64_array([0])}), + **regular_op_with_empty_data('ss_begin_concat', {'op': 'Concat', 'type': 'Concat'}), + + **regular_op_with_empty_data('ss_end_cast', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}), + **regular_op_with_empty_data('ss_end_clamp', {'op': 'Clamp', 'type': None}), + **regular_op_with_empty_data('ss_end_clamp_min', {'value': np.iinfo(np.int32).min, 'op': 'Const', 'type': 'Const'}), + **regular_op_with_empty_data('ss_end_clamp_max', {'value': np.iinfo(np.int32).max, 'op': 'Const', 'type': 'Const'}), + **regular_op_with_empty_data('ss_end_gather_0', {'op': 'Gather', 'type': 'Gather'}), + **valued_const_with_data('ss_end_gather_0_idx', int64_array([0])), + **regular_op_with_shaped_data('ss_end_gather_0_axis', [], {'op': 'Const', 'type': 'Const', 'value': [0]}), + **regular_op_with_empty_data('ss_end_gather_1', {'op': 'Gather', 'type': 'Gather'}), + **valued_const_with_data('ss_end_gather_1_idx', int64_array([1])), + **regular_op_with_shaped_data('ss_end_gather_1_axis', [], {'op': 'Const', 'type': 'Const', 'value': [0]}), + **regular_op_with_empty_data('ss_end_gather_2', {'op': 'Gather', 'type': 'Gather'}), + **valued_const_with_data('ss_end_gather_2_idx', int64_array([2])), + **regular_op_with_shaped_data('ss_end_gather_2_axis', [], {'op': 'Const', 'type': 'Const', 'value': [0]}), + **regular_op_with_empty_data('ss_end_gather_3', {'op': 'Gather', 'type': 'Gather'}), + **valued_const_with_data('ss_end_gather_3_idx', int64_array([3])), + **regular_op_with_shaped_data('ss_end_gather_3_axis', [], {'op': 'Const', 'type': 'Const', 'value': [0]}), + **regular_op_with_empty_data('ss_end_const_0', {'op': 'Const', 'type': 'Const', 'value': int64_array([0])}), + **regular_op_with_empty_data('ss_end_const_1', {'op': 'Const', 'type': 'Const', 'value': int64_array([0])}), + **regular_op_with_empty_data('ss_end_const_2', {'op': 'Const', 'type': 'Const', 'value': int64_array([0])}), + **regular_op_with_empty_data('ss_end_const_3', {'op': 'Const', 'type': 'Const', 'value': int64_array([0])}), + **regular_op_with_empty_data('ss_end_concat', {'op': 'Concat', 'type': 'Concat'}), + + **regular_op_with_empty_data('ss_strides', {'op': 'Const', 'type': 'Const'}), + **regular_op_with_empty_data('ss', {'op': 'StridedSlice', 'type': 'StridedSlice', + 'new_axis_mask': np.zeros(4, dtype=np.int64), + 'shrink_axis_mask': np.zeros(4, dtype=np.int64), + 'ellipsis_mask': np.zeros(4, dtype=np.int64)}), + **result('result') } +pattern_graph = [ + *connect('input:0', '0:slice'), + *connect('starts:0', '1:slice'), + *connect('ends:0', '2:slice'), + *connect('axes:0', '3:slice'), + *connect('steps:0', '4:slice'), + *connect('slice:0', '0:result') +] + +pattern_ref_graph = [ + *connect('input:0', '0:ss'), + *connect('starts:0', '0:ss_begin_clamp'), + *connect('ss_begin_clamp:0', '0:ss_begin_cast'), + *connect('ss_begin_clamp_min:0', '1:ss_begin_clamp'), + *connect('ss_begin_clamp_max:0', '2:ss_begin_clamp'), + *connect('ss_begin_concat:0', '1:ss'), + *connect('ends:0', '0:ss_end_clamp'), + *connect('ss_end_clamp:0', '0:ss_end_cast'), + *connect('ss_end_clamp_min:0', '1:ss_end_clamp'), + *connect('ss_end_clamp_max:0', '2:ss_end_clamp'), + *connect('ss_end_concat:0', '2:ss'), + *connect('ss_strides:0', '3:ss'), + *connect('ss:0', '0:result'), + + *connect('ss_begin_gather_0_idx:0', '1:ss_begin_gather_0'), + *connect('ss_begin_gather_0_axis:0', '2:ss_begin_gather_0'), + *connect('ss_begin_gather_1_idx:0', '1:ss_begin_gather_1'), + *connect('ss_begin_gather_1_axis:0', '2:ss_begin_gather_1'), + *connect('ss_begin_gather_2_idx:0', '1:ss_begin_gather_2'), + *connect('ss_begin_gather_2_axis:0', '2:ss_begin_gather_2'), + *connect('ss_begin_gather_3_idx:0', '1:ss_begin_gather_3'), + *connect('ss_begin_gather_3_axis:0', '2:ss_begin_gather_3'), + + *connect('ss_end_gather_0_idx:0', '1:ss_end_gather_0'), + *connect('ss_end_gather_0_axis:0', '2:ss_end_gather_0'), + *connect('ss_end_gather_1_idx:0', '1:ss_end_gather_1'), + *connect('ss_end_gather_1_axis:0', '2:ss_end_gather_1'), + *connect('ss_end_gather_2_idx:0', '1:ss_end_gather_2'), + *connect('ss_end_gather_2_axis:0', '2:ss_end_gather_2'), + *connect('ss_end_gather_3_idx:0', '1:ss_end_gather_3'), + *connect('ss_end_gather_3_axis:0', '2:ss_end_gather_3'), +] + class ConvertSliceTests(unittest.TestCase): - nodes_attributes = { - # input data - 'placeholder_1': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - # Slice layer inputs - 'starts': {'type': 'Const', 'kind': 'op', 'op': 'Const'}, - 'starts_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - 'ends': {'type': 'Const', 'kind': 'op', 'op': 'Const'}, - 'ends_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - 'strides': {'type': 'Const', 'kind': 'op', 'op': 'Const'}, - 'strides_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - 'axes': {'type': 'Const', 'kind': 'op', 'op': 'Const'}, - 'axes_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - 'steps': {'type': 'Const', 'kind': 'op', 'op': 'Const'}, - 'steps_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - # Slice layer - 'slice': {'type': 'Slice', 'kind': 'op', 'op': 'Slice', 'name': 'slice_node'}, - 'slice_data': {'value': None, 'shape': None, 'kind': 'data'}, - # Output operation - 'output_op': {'type': 'Const', 'kind': 'op', 'op': 'Const'}, - 'output_data': {'shape': None, 'kind': 'data', 'data_type': None}, - 'op_output': {'kind': 'op', 'op': 'Result'}, - # StridedSlice layer - 'strided_slice': {'kind': 'op', 'op': 'StridedSlice', 'slices': None, 'shrink_axis_mask': None} - } - - def test_slice_all_params(self): - input_shape = int64_array([5, 10, 20]) - starts_value = int64_array([4, 2]) - ends_value = int64_array([15, 8]) - axes_value = int64_array([2, 1]) - steps_value = int64_array([1, 1]) - - masks_value = np.zeros([len(input_shape)], dtype=np.int64) - graph = build_graph(self.nodes_attributes, - [('placeholder_1', 'placeholder_1_data'), - ('placeholder_1_data', 'slice', {'in': 0}), - ('starts', 'starts_data'), - ('starts_data', 'slice', {'in': 1}), - ('ends', 'ends_data'), - ('ends_data', 'slice', {'in': 2}), - ('axes', 'axes_data'), - ('axes_data', 'slice', {'in': 3}), - ('steps', 'steps_data'), - ('steps_data', 'slice', {'in': 4}), - ('slice', 'slice_data'), - ('slice_data', 'output_op'), - ('output_op', 'output_data'), - ('output_data', 'op_output') - ], - {'placeholder_1_data': {'shape': input_shape}, - 'starts': {'shape': starts_value.shape, 'value': starts_value}, - 'starts_data': {'shape': starts_value.shape, 'value': starts_value}, - 'ends': {'shape': ends_value.shape, 'value': ends_value}, - 'ends_data': {'shape': ends_value.shape, 'value': ends_value}, - 'steps': {'shape': steps_value.shape, 'value': steps_value}, - 'steps_data': {'shape': steps_value.shape, 'value': steps_value}, - 'axes': {'shape': axes_value.shape, 'value': axes_value}, - 'axes_data': {'shape': axes_value.shape, 'value': axes_value}, - }, nodes_with_edges_only=True - ) - slice_node = Node(graph, 'slice') - Slice.infer(slice_node) - - pattern = ConvertSlice() - pattern.find_and_replace_pattern(graph) - - ss_node = Node(graph, graph.get_node_id_by_name('slice_node')) - assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node' - - graph_ref = build_graph(self.nodes_attributes, - [('placeholder_1', 'placeholder_1_data'), - ('placeholder_1_data', 'strided_slice', {'in': 0}), - ('starts', 'starts_data'), - ('starts_data', 'strided_slice', {'in': 1}), - ('ends', 'ends_data'), - ('ends_data', 'strided_slice', {'in': 2}), - ('strides', 'strides_data'), - ('strides_data', 'strided_slice', {'in': 3}), - ('strided_slice', 'slice_data'), - ('slice_data', 'output_op'), - ('output_op', 'output_data'), - ('output_data', 'op_output') - ], - {'placeholder_1_data': {'shape': input_shape}, - 'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value, - 'ellipsis_mask': masks_value, 'begin_mask': int64_array([0, 1, 1]), - 'end_mask': int64_array([0, 1, 1])}, - 'slice_data': {'shape': int64_array([5, 6, 11])} - }, nodes_with_edges_only=True - ) - (flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True) + + def test_convert_slice_to_strided_slice_one_axis(self): + graph = build_graph( + nodes_attrs=nodes_attributes, + edges=pattern_graph, + update_attributes={ + 'starts': {'value': int64_array([0]), 'shape': [1]}, + 'ends': {'value': int64_array([1]), 'shape': [1]}, + 'axes': {'value': int64_array([0]), 'shape': [1]}, + 'axes_d': {'value': int64_array([0]), 'shape': [1]}, + 'steps': {'value': int64_array([1]), 'shape': [1]}, + 'steps_d': {'value': int64_array([1]), 'shape': [1]} + }, + nodes_with_edges_only=True + ) + + ref_graph = build_graph( + nodes_attrs=nodes_attributes, + edges=pattern_ref_graph + [ + *connect('ss_begin_cast:0', '0:ss_begin_gather_0'), + *connect('ss_begin_gather_0:0', '0:ss_begin_concat'), + *connect('ss_begin_const_1:0', '1:ss_begin_concat'), + *connect('ss_begin_const_2:0', '2:ss_begin_concat'), + *connect('ss_begin_const_3:0', '3:ss_begin_concat'), + + *connect('ss_end_cast:0', '0:ss_end_gather_0'), + *connect('ss_end_gather_0:0', '0:ss_end_concat'), + *connect('ss_end_const_1:0', '1:ss_end_concat'), + *connect('ss_end_const_2:0', '2:ss_end_concat'), + *connect('ss_end_const_3:0', '3:ss_end_concat'), + ], + update_attributes={ + 'starts': {'value': int64_array([0]), 'shape': [1]}, + 'ends': {'value': int64_array([1]), 'shape': [1]}, + 'ss_strides': {'value': int64_array([1, 1, 1, 1]), 'shape': [4]}, + 'ss': {'begin_mask': int64_array([1, 0, 0, 0]), 'end_mask': int64_array([1, 0, 0, 0])} + } + ) + ConvertSlice().find_and_replace_pattern(graph) + (flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True) self.assertTrue(flag, resp) - def test_no_steps_no_axes(self): - input_shape = int64_array([5, 10, 20]) - starts_value = int64_array([3, 2, 7]) - ends_value = int64_array([5, 8, 15]) - steps_value = int64_array([1, 1, 1]) - masks_value = np.zeros([len(input_shape)], dtype=np.int64) - graph = build_graph(self.nodes_attributes, - [('placeholder_1', 'placeholder_1_data'), - ('placeholder_1_data', 'slice', {'in': 0}), - ('starts', 'starts_data'), - ('starts_data', 'slice', {'in': 1}), - ('ends', 'ends_data'), - ('ends_data', 'slice', {'in': 2}), - ('slice', 'slice_data'), - ('slice_data', 'output_op'), - ('output_op', 'output_data'), - ('output_data', 'op_output') - ], - {'placeholder_1_data': {'shape': input_shape}, - 'starts': {'shape': starts_value.shape, 'value': starts_value}, - 'starts_data': {'shape': starts_value.shape, 'value': starts_value}, - 'ends': {'shape': ends_value.shape, 'value': ends_value}, - 'ends_data': {'shape': ends_value.shape, 'value': ends_value}, - }, nodes_with_edges_only=True - ) - slice_node = Node(graph, 'slice') - Slice.infer(slice_node) - - pattern = ConvertSlice() - pattern.find_and_replace_pattern(graph) - - ss_node = Node(graph, graph.get_node_id_by_name('slice_node')) - assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node' - - graph_ref = build_graph(self.nodes_attributes, - [('placeholder_1', 'placeholder_1_data'), - ('placeholder_1_data', 'strided_slice', {'in': 0}), - ('starts', 'starts_data'), - ('starts_data', 'strided_slice', {'in': 1}), - ('ends', 'ends_data'), - ('ends_data', 'strided_slice', {'in': 2}), - ('strides', 'strides_data'), - ('strides_data', 'strided_slice', {'in': 3}), - ('strided_slice', 'slice_data'), - ('slice_data', 'output_op'), - ('output_op', 'output_data'), - ('output_data', 'op_output') - ], - {'placeholder_1_data': {'shape': input_shape}, - 'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value, - 'ellipsis_mask': masks_value, 'begin_mask': np.ones([3]), - 'end_mask': np.ones([3])}, - 'slice_data': {'shape': int64_array([2, 6, 8])} - }, nodes_with_edges_only=True - ) - (flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True) + def test_convert_slice_to_strided_slice_one_axis_steps_is_2(self): + graph = build_graph( + nodes_attrs=nodes_attributes, + edges=pattern_graph, + update_attributes={ + 'starts': {'value': int64_array([0]), 'shape': [1]}, + 'ends': {'value': int64_array([150]), 'shape': [1]}, + 'axes': {'value': int64_array([2]), 'shape': [1]}, + 'axes_d': {'value': int64_array([2]), 'shape': [1]}, + 'steps': {'value': int64_array([2]), 'shape': [1]}, + 'steps_d': {'value': int64_array([2]), 'shape': [1]} + }, + nodes_with_edges_only=True + ) + + ref_graph = build_graph( + nodes_attrs=nodes_attributes, + edges=pattern_ref_graph + [ + *connect('ss_begin_cast:0', '0:ss_begin_gather_0'), + *connect('ss_begin_gather_0:0', '2:ss_begin_concat'), + *connect('ss_begin_const_0:0', '0:ss_begin_concat'), + *connect('ss_begin_const_1:0', '1:ss_begin_concat'), + *connect('ss_begin_const_3:0', '3:ss_begin_concat'), + + *connect('ss_end_cast:0', '0:ss_end_gather_0'), + *connect('ss_end_gather_0:0', '2:ss_end_concat'), + *connect('ss_end_const_0:0', '0:ss_end_concat'), + *connect('ss_end_const_1:0', '1:ss_end_concat'), + *connect('ss_end_const_3:0', '3:ss_end_concat'), + ], + update_attributes={ + 'starts': {'value': int64_array([0]), 'shape': [1]}, + 'ends': {'value': int64_array([150]), 'shape': [1]}, + 'ss_strides': {'value': int64_array([1, 1, 2, 1]), 'shape': [4]}, + 'ss': {'begin_mask': int64_array([0, 0, 1, 0]), 'end_mask': int64_array([0, 0, 1, 0])} + } + ) + ConvertSlice().find_and_replace_pattern(graph) + (flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True) self.assertTrue(flag, resp) - def test_no_axes(self): - input_shape = int64_array([5, 10, 20]) - starts_value = int64_array([3, 2, 7]) - ends_value = int64_array([5, 8, 15]) - steps_value = int64_array([2, 3, 1]) - masks_value = np.zeros([len(input_shape)], dtype=np.int64) - graph = build_graph(self.nodes_attributes, - [('placeholder_1', 'placeholder_1_data'), - ('placeholder_1_data', 'slice', {'in': 0}), - ('starts', 'starts_data'), - ('starts_data', 'slice', {'in': 1}), - ('ends', 'ends_data'), - ('ends_data', 'slice', {'in': 2}), - ('steps', 'steps_data'), - ('steps_data', 'slice', {'in': 4}), - ('slice', 'slice_data'), - ('slice_data', 'output_op'), - ('output_op', 'output_data'), - ('output_data', 'op_output') - ], - {'placeholder_1_data': {'shape': input_shape}, - 'starts': {'shape': starts_value.shape, 'value': starts_value}, - 'starts_data': {'shape': starts_value.shape, 'value': starts_value}, - 'ends': {'shape': ends_value.shape, 'value': ends_value}, - 'ends_data': {'shape': ends_value.shape, 'value': ends_value}, - 'steps': {'shape': steps_value.shape, 'value': steps_value}, - 'steps_data': {'shape': steps_value.shape, 'value': steps_value}, - }, nodes_with_edges_only=True - ) - slice_node = Node(graph, 'slice') - Slice.infer(slice_node) - - pattern = ConvertSlice() - pattern.find_and_replace_pattern(graph) - - ss_node = Node(graph, graph.get_node_id_by_name('slice_node')) - assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node' - - graph_ref = build_graph(self.nodes_attributes, - [('placeholder_1', 'placeholder_1_data'), - ('placeholder_1_data', 'strided_slice', {'in': 0}), - ('starts', 'starts_data'), - ('starts_data', 'strided_slice', {'in': 1}), - ('ends', 'ends_data'), - ('ends_data', 'strided_slice', {'in': 2}), - ('strides', 'strides_data'), - ('strides_data', 'strided_slice', {'in': 3}), - ('strided_slice', 'slice_data'), - ('slice_data', 'output_op'), - ('output_op', 'output_data'), - ('output_data', 'op_output') - ], - {'placeholder_1_data': {'shape': input_shape}, - 'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value, - 'ellipsis_mask': masks_value, 'begin_mask': np.ones([3]), - 'end_mask': np.ones([3])}, - 'slice_data': {'shape': int64_array([1, 2, 8])} - }, nodes_with_edges_only=True - ) - (flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True) + def test_convert_slice_to_strided_slice_two_axes(self): + graph = build_graph( + nodes_attrs=nodes_attributes, + edges=pattern_graph, + update_attributes={ + 'starts': {'value': int64_array([0, 0]), 'shape': [2]}, + 'ends': {'value': int64_array([150, 150]), 'shape': [2]}, + 'axes': {'value': int64_array([2, 3]), 'shape': [2]}, + 'axes_d': {'value': int64_array([2, 3]), 'shape': [2]}, + 'steps': {'value': int64_array([1, 1]), 'shape': [2]}, + 'steps_d': {'value': int64_array([1, 1]), 'shape': [2]} + }, + nodes_with_edges_only=True + ) + + ref_graph = build_graph( + nodes_attrs=nodes_attributes, + edges=pattern_ref_graph + [ + *connect('ss_begin_cast:0', '0:ss_begin_gather_0'), + *connect('ss_begin_gather_0:0', '2:ss_begin_concat'), + *connect_data('ss_begin_cast:0', '0:ss_begin_gather_1'), + *connect('ss_begin_gather_1:0', '3:ss_begin_concat'), + *connect('ss_begin_const_0:0', '0:ss_begin_concat'), + *connect('ss_begin_const_1:0', '1:ss_begin_concat'), + + *connect('ss_end_cast:0', '0:ss_end_gather_0'), + *connect('ss_end_gather_0:0', '2:ss_end_concat'), + *connect_data('ss_end_cast:0', '0:ss_end_gather_1'), + *connect('ss_end_gather_1:0', '3:ss_end_concat'), + *connect('ss_end_const_0:0', '0:ss_end_concat'), + *connect('ss_end_const_1:0', '1:ss_end_concat'), + ], + update_attributes={ + 'starts': {'value': int64_array([0, 0]), 'shape': [2]}, + 'ends': {'value': int64_array([150, 150]), 'shape': [2]}, + 'ss_strides': {'value': int64_array([1, 1, 1, 1]), 'shape': [4]}, + 'ss': {'begin_mask': int64_array([0, 0, 1, 1]), 'end_mask': int64_array([0, 0, 1, 1])} + } + ) + ConvertSlice().find_and_replace_pattern(graph) + (flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True) self.assertTrue(flag, resp) - def test_no_steps(self): - input_shape = int64_array([5, 10, 20]) - starts_value = int64_array([4, 2]) - ends_value = int64_array([15, 8]) - axes_value = int64_array([2, 1]) - masks_value = np.zeros([len(input_shape)], dtype=np.int64) - graph = build_graph(self.nodes_attributes, - [('placeholder_1', 'placeholder_1_data'), - ('placeholder_1_data', 'slice', {'in': 0}), - ('starts', 'starts_data'), - ('starts_data', 'slice', {'in': 1}), - ('ends', 'ends_data'), - ('ends_data', 'slice', {'in': 2}), - ('axes', 'axes_data'), - ('axes_data', 'slice', {'in': 3}), - ('slice', 'slice_data'), - ('slice_data', 'output_op'), - ('output_op', 'output_data'), - ('output_data', 'op_output') - ], - {'placeholder_1_data': {'shape': input_shape}, - 'starts': {'shape': starts_value.shape, 'value': starts_value}, - 'starts_data': {'shape': starts_value.shape, 'value': starts_value}, - 'ends': {'shape': ends_value.shape, 'value': ends_value}, - 'ends_data': {'shape': ends_value.shape, 'value': ends_value}, - 'axes': {'shape': axes_value.shape, 'value': axes_value}, - 'axes_data': {'shape': axes_value.shape, 'value': axes_value}, - }, nodes_with_edges_only=True - ) - slice_node = Node(graph, 'slice') - Slice.infer(slice_node) - - pattern = ConvertSlice() - pattern.find_and_replace_pattern(graph) - - ss_node = Node(graph, graph.get_node_id_by_name('slice_node')) - assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node' - - graph_ref = build_graph(self.nodes_attributes, - [('placeholder_1', 'placeholder_1_data'), - ('placeholder_1_data', 'strided_slice', {'in': 0}), - ('starts', 'starts_data'), - ('starts_data', 'strided_slice', {'in': 1}), - ('ends', 'ends_data'), - ('ends_data', 'strided_slice', {'in': 2}), - ('strides', 'strides_data'), - ('strides_data', 'strided_slice', {'in': 3}), - ('strided_slice', 'slice_data'), - ('slice_data', 'output_op'), - ('output_op', 'output_data'), - ('output_data', 'op_output') - ], - {'placeholder_1_data': {'shape': input_shape}, - 'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value, - 'ellipsis_mask': masks_value, 'begin_mask': int64_array([0, 1, 1]), - 'end_mask': int64_array([0, 1, 1])}, - 'slice_data': {'shape': int64_array([5, 6, 11])} - }, nodes_with_edges_only=True - ) - (flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True) + def test_convert_slice_to_strided_slice_three_axes(self): + graph = build_graph( + nodes_attrs=nodes_attributes, + edges=pattern_graph, + update_attributes={ + 'starts': {'value': int64_array([0, 0, 0]), 'shape': [3]}, + 'ends': {'value': int64_array([2, 150, 150]), 'shape': [3]}, + 'axes': {'value': int64_array([1, 2, 3]), 'shape': [3]}, + 'axes_d': {'value': int64_array([1, 2, 3]), 'shape': [3]}, + 'steps': {'value': int64_array([1, 1, 1]), 'shape': [3]}, + 'steps_d': {'value': int64_array([1, 1, 1]), 'shape': [3]} + }, + nodes_with_edges_only=True + ) + + ref_graph = build_graph( + nodes_attrs=nodes_attributes, + edges=pattern_ref_graph + [ + *connect('ss_begin_cast:0', '0:ss_begin_gather_0'), + *connect('ss_begin_gather_0:0', '1:ss_begin_concat'), + *connect_data('ss_begin_cast:0', '0:ss_begin_gather_1'), + *connect('ss_begin_gather_1:0', '2:ss_begin_concat'), + *connect_data('ss_begin_cast:0', '0:ss_begin_gather_2'), + *connect('ss_begin_gather_2:0', '3:ss_begin_concat'), + *connect('ss_begin_const_0:0', '0:ss_begin_concat'), + + *connect('ss_end_cast:0', '0:ss_end_gather_0'), + *connect('ss_end_gather_0:0', '1:ss_end_concat'), + *connect_data('ss_end_cast:0', '0:ss_end_gather_1'), + *connect('ss_end_gather_1:0', '2:ss_end_concat'), + *connect_data('ss_end_cast:0', '0:ss_end_gather_2'), + *connect('ss_end_gather_2:0', '3:ss_end_concat'), + *connect('ss_end_const_0:0', '0:ss_end_concat'), + ], + update_attributes={ + 'starts': {'value': int64_array([0, 0, 0]), 'shape': [3]}, + 'ends': {'value': int64_array([2, 150, 150]), 'shape': [3]}, + 'ss_strides': {'value': int64_array([1, 1, 1, 1]), 'shape': [4]}, + 'ss': {'begin_mask': int64_array([0, 1, 1, 1]), 'end_mask': int64_array([0, 1, 1, 1])} + } + ) + ConvertSlice().find_and_replace_pattern(graph) + (flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True) self.assertTrue(flag, resp) + + def test_convert_slice_to_strided_slice_not_sorted_axes(self): + graph = build_graph( + nodes_attrs=nodes_attributes, + edges=pattern_graph, + update_attributes={ + 'starts': {'value': int64_array([0, 1, 1, 0]), 'shape': [4]}, + 'ends': {'value': int64_array([1, 150, 150, 2]), 'shape': [4]}, + 'axes': {'value': int64_array([0, 2, 3, 1]), 'shape': [4]}, + 'axes_d': {'value': int64_array([0, 2, 3, 1]), 'shape': [4]}, + 'steps': {'value': int64_array([1, 1, 1, 1]), 'shape': [4]}, + 'steps_d': {'value': int64_array([1, 1, 1, 1]), 'shape': [4]} + }, + nodes_with_edges_only=True + ) + + ref_graph = build_graph( + nodes_attrs=nodes_attributes, + edges=pattern_ref_graph + [ + *connect('ss_begin_cast:0', '0:ss_begin_gather_0'), + *connect('ss_begin_gather_0:0', '0:ss_begin_concat'), + *connect_data('ss_begin_cast:0', '0:ss_begin_gather_1'), + *connect('ss_begin_gather_1:0', '2:ss_begin_concat'), + *connect_data('ss_begin_cast:0', '0:ss_begin_gather_2'), + *connect('ss_begin_gather_2:0', '3:ss_begin_concat'), + *connect_data('ss_begin_cast:0', '0:ss_begin_gather_3'), + *connect('ss_begin_gather_3:0', '1:ss_begin_concat'), + + *connect('ss_end_cast:0', '0:ss_end_gather_0'), + *connect('ss_end_gather_0:0', '0:ss_end_concat'), + *connect_data('ss_end_cast:0', '0:ss_end_gather_1'), + *connect('ss_end_gather_1:0', '2:ss_end_concat'), + *connect_data('ss_end_cast:0', '0:ss_end_gather_2'), + *connect('ss_end_gather_2:0', '3:ss_end_concat'), + *connect_data('ss_end_cast:0', '0:ss_end_gather_3'), + *connect('ss_end_gather_3:0', '1:ss_end_concat'), + ], + update_attributes={ + 'starts': {'value': int64_array([0, 1, 1, 0]), 'shape': [4]}, + 'ends': {'value': int64_array([1, 150, 150, 2]), 'shape': [4]}, + 'ss_strides': {'value': int64_array([1, 1, 1, 1]), 'shape': [4]}, + 'ss': {'begin_mask': int64_array([1, 1, 1, 1]), 'end_mask': int64_array([1, 1, 1, 1])} + } + ) + ConvertSlice().find_and_replace_pattern(graph) + (flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True) + self.assertTrue(flag, resp) + + def test_convert_slice_to_strided_slice_without_axes_and_steps(self): + graph = build_graph( + nodes_attrs=nodes_attributes, + edges=[ + *connect('input:0', '0:slice'), + *connect('starts:0', '1:slice'), + *connect('ends:0', '2:slice'), + *connect('slice:0', '0:result') + ], + update_attributes={ + 'starts': {'value': int64_array([0, 0, 0, 0]), 'shape': [4]}, + 'ends': {'value': int64_array([1, 2, 150, 150]), 'shape': [4]}, + }, + nodes_with_edges_only=True + ) + + ref_graph = build_graph( + nodes_attrs=nodes_attributes, + edges=pattern_ref_graph + [ + *connect('ss_begin_cast:0', '0:ss_begin_gather_0'), + *connect('ss_begin_gather_0:0', '0:ss_begin_concat'), + *connect_data('ss_begin_cast:0', '0:ss_begin_gather_1'), + *connect('ss_begin_gather_1:0', '1:ss_begin_concat'), + *connect_data('ss_begin_cast:0', '0:ss_begin_gather_2'), + *connect('ss_begin_gather_2:0', '2:ss_begin_concat'), + *connect_data('ss_begin_cast:0', '0:ss_begin_gather_3'), + *connect('ss_begin_gather_3:0', '3:ss_begin_concat'), + + *connect('ss_end_cast:0', '0:ss_end_gather_0'), + *connect('ss_end_gather_0:0', '0:ss_end_concat'), + *connect_data('ss_end_cast:0', '0:ss_end_gather_1'), + *connect('ss_end_gather_1:0', '1:ss_end_concat'), + *connect_data('ss_end_cast:0', '0:ss_end_gather_2'), + *connect('ss_end_gather_2:0', '2:ss_end_concat'), + *connect_data('ss_end_cast:0', '0:ss_end_gather_3'), + *connect('ss_end_gather_3:0', '3:ss_end_concat'), + ], + update_attributes={ + 'starts': {'value': int64_array([0, 0, 0, 0]), 'shape': [4]}, + 'ends': {'value': int64_array([1, 2, 150, 150]), 'shape': [4]}, + 'ss_strides': {'value': int64_array([1, 1, 1, 1]), 'shape': [4]}, + 'ss': {'begin_mask': int64_array([1, 1, 1, 1]), 'end_mask': int64_array([1, 1, 1, 1])} + } + ) + ConvertSlice().find_and_replace_pattern(graph) + (flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True) + self.assertTrue(flag, resp) \ No newline at end of file