Skip to content

Commit

Permalink
Corrected preserving node names, renamed attributes names, added test…
Browse files Browse the repository at this point in the history
…s fro slice_replacer onnx phase
  • Loading branch information
pavel-esir committed Jul 17, 2020
1 parent 9a65e3a commit 26b3843
Show file tree
Hide file tree
Showing 10 changed files with 241 additions and 158 deletions.
12 changes: 6 additions & 6 deletions model-optimizer/extensions/front/onnx/slice_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ class SliceFrontExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node):
if get_onnx_opset_version(node) < 10:
axis = int64_array(onnx_attr(node, 'axes', 'ints', default=[]))
start = int64_array(onnx_attr(node, 'starts', 'ints', default=[]))
end = int64_array(onnx_attr(node, 'ends', 'ints', default=[]))
axes = int64_array(onnx_attr(node, 'axes', 'ints', default=[]))
starts = int64_array(onnx_attr(node, 'starts', 'ints', default=[]))
ends = int64_array(onnx_attr(node, 'ends', 'ints', default=[]))

attrs = {
'axis': axis if len(axis) != 0 else None,
'start': start if len(start) != 0 else None,
'end': end if len(end) != 0 else None,
'axes': axes if len(axes) != 0 else None,
'starts': starts if len(starts) != 0 else None,
'ends': ends if len(ends) != 0 else None,
}
AttributedSlice.update_node_stat(node, attrs)
else: # onnx_opset_version >= 10
Expand Down
24 changes: 10 additions & 14 deletions model-optimizer/extensions/front/onnx/slice_replacers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
limitations under the License.
"""

import numpy as np

from mo.front.common.replacement import FrontReplacementOp
from mo.graph.graph import Graph
from mo.ops.const import Const
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes
from mo.ops.slice import Slice


Expand All @@ -29,18 +31,12 @@ class AttributedSliceToSliceReplacer(FrontReplacementOp):

def replace_sub_graph(self, graph: Graph, match: dict):
node = match['op']
slice_name = node.soft_get('name', node.id)

slice_node = Slice(graph, {'name': node.id + '/slice_'}).create_node()
node.in_port(0).get_connection().set_destination(slice_node.in_port(0))
node.out_port(0).get_connection().set_source(slice_node.out_port(0))

start_node = Const(graph, {'value': node.start, 'name': node.id + '/start_const'}).create_node()
end_node = Const(graph, {'value': node.end, 'name': node.id + '/end_const'}).create_node()
axes = node.axes if node.has_valid('axes') else np.arange(len(node.starts), dtype=np.int32)

slice_node.in_port(1).get_connection().set_source(start_node.out_port(0))
slice_node.in_port(2).get_connection().set_source(end_node.out_port(0))
if node.has_valid('axis'):
axis_node = Const(graph, {'value': node.axis, 'name': node.id + '/axis_const'}).create_node()
# slice_node.add_input_port(3, skip_if_exist=True)
slice_node.in_port(3).get_connection().set_source(axis_node.out_port(0))
slice_node = create_op_with_const_inputs(graph, Slice, {1: node.starts, 2: node.ends, 3: axes})
rename_nodes([(node, slice_name + '/to_be_removed'), (slice_node, slice_name)])

node.in_port(0).get_connection().set_destination(slice_node.in_port(0))
node.out_port(0).get_connection().set_source(slice_node.out_port(0))
63 changes: 63 additions & 0 deletions model-optimizer/extensions/front/onnx/slice_replacers_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

import numpy as np

from extensions.front.onnx.slice_replacers import AttributedSliceToSliceReplacer
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, const

nodes = {
**regular_op_with_empty_data('input', {'type': 'Parameter'}),
**regular_op_with_empty_data('attributed_slice', {'op': 'AttributedSlice', 'type': None,
# todo: add test for the case when does not have axes attribute
# 'start': np.array([0, 0]), 'end': np.array([1, -1]), 'axis': np.array([0, 1])}),
'starts': np.array([0, 0]), 'ends': np.array([1, -1])}),
**result(),

# nodes after replacement
**const('start', np.array([0, 0])),
**const('end', np.array([1, -1])),
**const('axis', np.array(np.array([0, 1]))),
**regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}),
}


class SliceReplacerTest(unittest.TestCase):

def test_attributed_slice_replacer(self):
graph = build_graph(nodes_attrs=nodes, edges=[
('input', 'attributed_slice', {'out': 0}),
('attributed_slice', 'output', {'out': 0}),
], nodes_with_edges_only=True)
graph.stage = 'front'

AttributedSliceToSliceReplacer().find_and_replace_pattern(graph)

graph_ref = build_graph(nodes_attrs=nodes, edges=[
('input', 'slice', {'out': 0}),

('start', 'slice', {'out': 0, 'in': 1}),
('end', 'slice', {'out': 0, 'in': 2}),
('axis', 'slice', {'out': 0, 'in': 3}),

('slice', 'output', {'out': 0}),
], nodes_with_edges_only=True)

(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
40 changes: 22 additions & 18 deletions model-optimizer/extensions/front/tf/slice_replacers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

import numpy as np

from extensions.ops.elementwise import Add, Equal
from extensions.ops.select import Select
from mo.front.common.replacement import FrontReplacementOp
from mo.graph.graph import Graph
from mo.graph.graph import Graph, rename_nodes
from mo.ops.const import Const
from mo.ops.eltwise import Eltwise
from mo.ops.slice import Slice
from extensions.ops.select import Select


class TFSliceToSliceReplacer(FrontReplacementOp):
Expand All @@ -36,29 +36,33 @@ def replace_sub_graph(self, graph: Graph, match: dict):
on the second input while Slice has ends. This transformation was added to avoid multiple ifs in the future.
"""
node = match['op']
begin_node = node.in_node(1)
size_node = node.in_node(2)
slice_name = node.soft_get('name', node.id)
slice_node = Slice(graph).create_node()
rename_nodes([(node, slice_name + '/to_be_removed'), (slice_node, slice_name)])

eq_node = Eltwise(graph, dict(operation='equal', name=node.id + '/equal')).create_node()
minus_one_node = Const(graph, dict(value=np.array(-1), name=node.id + '/minus_one')).create_node()
int32_max_node = Const(graph, dict(value=np.iinfo(np.int32).max, name=node.id + '/int32_max')).create_node()
select_node = Select(graph, dict(name=node.id + '/select')).create_node()
eq_node = Equal(graph, {'name': slice_name + '/equal'}).create_node()
minus_one_node = Const(graph, {'name': slice_name + '/minus_one', 'value': np.array(-1)}).create_node()
int32_max_node = Const(graph, {'name': slice_name + '/int32_max', 'value': np.iinfo(np.int32).max}).create_node()
select_node = Select(graph, {'name': slice_name + '/select'}).create_node()

# node to convert sizes to ends
sum_node = Eltwise(graph, dict(operation='sum', name=node.id + '/end_const')).create_node()
slice_node = Slice(graph, dict(name=node.id + '/slice_')).create_node()
sum_node = Add(graph, {'name': slice_name + '/end_const'}).create_node()

# reconnect input from tfslice to slice
node.in_port(0).get_connection().set_destination(slice_node.in_port(0))
# connect begin of tfslice to start of slice
node.in_port(1).get_connection().set_destination(slice_node.in_port(1))
node.in_port(0).get_source().connect(slice_node.in_port(0))
node.in_port(0).disconnect()
# reconnect begin of tfslice to start of slice
node.in_port(1).get_source().connect(slice_node.in_port(1))
node.in_port(1).disconnect()

# (size -> ends) connect begins and sizes to sum to evaluate ends for Slice
begin_node.out_port(0).connect(sum_node.in_port(0))
node.in_port(2).get_connection().set_destination(sum_node.in_port(1))
# (size -> ends) reconnect begins and sizes to sum to evaluate ends for Slice
# connects begins to slice
slice_node.in_port(1).get_source().connect(sum_node.in_port(0))
node.in_port(2).get_source().connect(sum_node.in_port(1))
node.in_port(2).disconnect()

# if size[i] == -1 when take int32_max as end[i]
size_node.out_port(0).connect(eq_node.in_port(0))
sum_node.in_port(1).get_source().connect(eq_node.in_port(0))
minus_one_node.out_port(0).connect(eq_node.in_port(1))
# from equal to 0 port of select
eq_node.out_port(0).connect(select_node.in_port(0))
Expand Down
40 changes: 38 additions & 2 deletions model-optimizer/extensions/front/tf/slice_replacers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,56 @@
**regular_op_with_empty_data('tfslice', {'op': 'TFSlice', 'type': None}),
**const('begin', np.array(0)),
**const('size', np.array([-1])),
**regular_op_with_empty_data('john_doe', {'op': 'Sum', 'type': None}),
**result(),

# nodes after replacement
**const('minus_one', np.array(-1)),
**const('int32_max', np.array(np.iinfo(np.int32).max)),
**regular_op_with_empty_data('end_const', {'op': 'Add', 'type': 'Eltwise'}),
**regular_op_with_empty_data('equal', {'op': 'Equal', 'type': 'Eltwise'}),
**regular_op_with_empty_data('end_const', {'op': 'Add', 'type': 'Add'}),
**regular_op_with_empty_data('equal', {'op': 'Equal', 'type': 'Equal'}),
**regular_op_with_empty_data('select', {'op': 'Select', 'type': 'Select'}),
**regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}),
}


class SliceReplacerTest(unittest.TestCase):

def test_slice_replacer_begin_with_2_inputs(self):
graph = build_graph(nodes_attrs=nodes, edges=[
('input', 'tfslice', {'out': 0}),
('begin', 'tfslice', {'out': 0, 'in': 1}),
('begin', 'john_doe', {'out': 0, 'in': 0}),
('size', 'tfslice', {'out': 0, 'in': 2}),
('tfslice', 'output', {'out': 0}),
], nodes_with_edges_only=True)
graph.stage = 'front'

TFSliceToSliceReplacer().find_and_replace_pattern(graph)

graph_ref = build_graph(nodes_attrs=nodes, edges=[
('input', 'slice', {'out': 0}),
('begin', 'slice', {'out': 0, 'in': 1}),
('begin', 'john_doe', {'out': 0, 'in': 1}),

('begin', 'end_const', {'out': 0, 'in': 0}),
('size', 'end_const', {'out': 0, 'in': 1}),
('size', 'equal', {'out': 0, 'in': 0}),

('int32_max', 'select', {'out': 0, 'in': 1}),
('minus_one', 'equal', {'out': 0, 'in': 1}),

('equal', 'select', {'out': 0, 'in': 0}),

('end_const', 'select', {'out': 0, 'in': 2}),
('select', 'slice', {'out': 0, 'in': 2}),

('slice', 'output', {'out': 0}),
], nodes_with_edges_only=True)

(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)

def test_slice_replacer(self):
graph = build_graph(nodes_attrs=nodes, edges=[
('input', 'tfslice', {'out': 0}),
Expand Down
6 changes: 3 additions & 3 deletions model-optimizer/extensions/middle/SliceConverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def replace_pattern(self, graph: Graph, match: dict):
ss_end_mask = np.zeros(len(input_shape), dtype=np.int32)
ss_begin = np.zeros(len(input_shape), dtype=np.int32)
ss_end = np.zeros(len(input_shape), dtype=np.int32)
ss_steps = np.ones(len(input_shape), dtype=np.int32)
ss_step = np.ones(len(input_shape), dtype=np.int32)

# prepare inputs and attributes for the StridedSlice layer
for i, axis in enumerate(axes):
Expand All @@ -94,13 +94,13 @@ def replace_pattern(self, graph: Graph, match: dict):
ss_end_mask[axis] = 1
ss_end[axis] = ends[i]

ss_steps[axis] = steps[i]
ss_step[axis] = steps[i]

slice_node_name = node.soft_get('name', node.id)

begin_node = Const(graph, {'value': ss_begin, 'name': slice_node_name + '/begin'}).create_node()
end_node = Const(graph, {'value': ss_end, 'name': slice_node_name + '/end'}).create_node()
strides_node = Const(graph, {'value': ss_steps, 'name': slice_node_name + '/stride'}).create_node()
strides_node = Const(graph, {'value': ss_step, 'name': slice_node_name + '/stride'}).create_node()

ss = StridedSlice(graph, dict(new_axis_mask=np.zeros(len(output_shape), dtype=np.int32),
shrink_axis_mask=np.zeros(len(output_shape), dtype=np.int32),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class ConvertSliceTests(unittest.TestCase):
'steps': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
'steps_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
# Slice layer
'slice': {'type': 'Slice', 'kind': 'op', 'op': 'Slice', 'format': 'onnx', 'end': None, 'name': 'slice_node'},
'slice': {'type': 'Slice', 'kind': 'op', 'op': 'Slice', 'name': 'slice_node'},
'slice_data': {'value': None, 'shape': None, 'kind': 'data'},
# Output operation
'output_op': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
Expand Down
28 changes: 14 additions & 14 deletions model-optimizer/mo/ops/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, graph: Graph, attrs: dict):
}, attrs)

def supported_attrs(self):
return ['axis', 'start', 'end']
return ['axes', 'starts', 'ends']


class CaffeSlice(Op):
Expand Down Expand Up @@ -122,11 +122,11 @@ class Slice(Op):
op = 'Slice'
enabled = False

def __init__(self, graph: Graph, attrs: dict):
def __init__(self, graph: Graph, attrs: dict = None):
super().__init__(graph, {
'type': None,
'op': 'Slice',
'in_ports_count': 4,
'in_ports_count': 5,
'out_ports_count': 1,
'infer': __class__.infer
}, attrs)
Expand All @@ -136,29 +136,29 @@ def infer(node: Node):
input_value = node.in_port(0).data.get_value()
input_shape = node.in_port(0).data.get_shape()

start = node.in_port(1).data.get_value()
end = node.in_port(2).data.get_value()
if start is None or end is None:
raise Error('The non-constant start/end values for Slice operation "{}" is not supported'.format(node.name))
starts = node.in_port(1).data.get_value()
ends = node.in_port(2).data.get_value()
if starts is None or ends is None:
raise Error('The non-constant start/end values for Slice operation "{}" are not supported'.format(node.name))

if node.is_in_port_connected(3):
axis = node.in_port(3).data.get_value()
if axis is None:
raise Error('The non-constant axis values for Slice operation "{}" is not supported'.format(node.name))
axes = node.in_port(3).data.get_value()
if axes is None:
raise Error('The non-constant axes values for Slice operation "{}" is not supported'.format(node.name))
else:
axis = [x for x in range(len(start))]
axes = [x for x in range(len(starts))]

if node.is_in_port_connected(4):
steps = node.in_port(4).data.get_value()
if steps is None:
raise Error('The non-constant steps values for Slice operation "{}" is not supported'.format(node.name))
else:
steps = np.ones(start.size, dtype=np.int64)
steps = np.ones(len(starts), dtype=np.int64)

slice_idx = [slice(0, in_shape, 1) for in_shape in input_shape]
for i in range(len(axis)):
for i in range(len(axes)):
# Ranged for output value for specified axis
slice_idx[axis[i]] = slice(start[i], end[i], steps[i])
slice_idx[axes[i]] = slice(starts[i], ends[i], steps[i])

if input_value is None:
output_shape = get_shape_after_slice(input_shape, slice_idx)
Expand Down
Loading

0 comments on commit 26b3843

Please sign in to comment.