diff --git a/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md b/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md index 50920f6e4c0cfa..78b47d278187d7 100644 --- a/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md +++ b/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md @@ -158,7 +158,7 @@ Standard TensorFlow\* operations: | FloorDiv | No | | FusedBatchNorm | No | | Gather | No | -| GatherNd | Supported if it can be replaced with Gather | +| GatherNd | No | | GatherV2 | No | | Greater | No | | GreaterEqual | No | @@ -337,6 +337,7 @@ Standard ONNX\* operators: | Floor | No | | GRU | No | | Gather | No | +| GatherND | No | | GatherTree | No | | Gemm | No | | GlobalAveragePool | No | diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index d8e58a894c2c21..60dcace71c8152 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -258,6 +258,7 @@ extensions/front/onnx/expand_ext.py extensions/front/onnx/flatten_ext.py extensions/front/onnx/flattenONNX_to_reshape.py extensions/front/onnx/gather_ext.py +extensions/front/onnx/gathernd_ext.py extensions/front/onnx/gemm_ext.py extensions/front/onnx/group_norm_ext.py extensions/front/onnx/gru_ext.py @@ -382,6 +383,7 @@ extensions/front/tf/FlattenToReshape.py extensions/front/tf/floor_div_decomposition.py extensions/front/tf/floor_ext.py extensions/front/tf/gather_ext.py +extensions/front/tf/gathernd_ext.py extensions/front/tf/GatherTree_ext.py extensions/front/tf/GNMT_DynamicSequenceLengths.py extensions/front/tf/identity_ext.py @@ -617,7 +619,7 @@ extensions/ops/ExtractImagePatches.py extensions/ops/fake_output.py extensions/ops/fakequantize.py extensions/ops/gather.py -extensions/ops/GatherNd.py +extensions/ops/gathernd.py extensions/ops/GatherTree.py extensions/ops/gelu.py extensions/ops/grn.py diff --git a/model-optimizer/extensions/front/onnx/gathernd_ext.py 
class GatherNDFrontExtractor(FrontExtractorOp):
    """Front extractor registered for nodes whose op is ``GatherND`` (ONNX front end).

    The extractor asks ``onnx_attr`` for the node's ``batch_dims`` attribute
    (type code ``'i'``, default ``0``) and records the result on the node
    through ``GatherND.update_node_stat``.
    """
    op = 'GatherND'
    enabled = True

    @classmethod
    def extract(cls, node):
        """Store ``batch_dims`` on ``node`` and return the extractor's ``enabled`` flag."""
        batch_dims = onnx_attr(node, 'batch_dims', 'i', default=0)
        GatherND.update_node_stat(node, {'batch_dims': batch_dims})
        return cls.enabled
class GatherNDFrontExtractor(FrontExtractorOp):
    """Front extractor registered for nodes whose op is ``GatherNd`` (TensorFlow front end).

    No attribute is read from the node: ``batch_dims`` is always recorded as ``0``
    through ``GatherND.update_node_stat``.
    """
    op = 'GatherNd'
    enabled = True

    @classmethod
    def extract(cls, node):
        """Store ``batch_dims = 0`` on ``node`` and return the extractor's ``enabled`` flag."""
        GatherND.update_node_stat(node, dict(batch_dims=0))
        return cls.enabled
""" enabled = True force_clean_up = True @@ -44,7 +47,7 @@ def run_after(self): def pattern(self): return dict( - nodes=[('GatherNd', dict(kind='op', op='GatherNd'))], + nodes=[('GatherND', dict(kind='op', op='GatherND', batch_dims=0))], edges=[] ) @@ -67,7 +70,7 @@ def indices_check(indices: np.array, input_shape: tuple): return non_zero def replace_pattern(self, graph: Graph, match: dict): - gather = match['GatherNd'] + gather = match['GatherND'] gather_name = gather.soft_get('name', gather.id) input_shape = gather.in_node(0).shape indices = gather.in_node(1).value @@ -75,16 +78,16 @@ def replace_pattern(self, graph: Graph, match: dict): # We can't do such special pass without indices value return - # 0. All needed checks that we can replace GatherNd by Gather + # 0. All needed checks that we can replace GatherND by Gather gather_idx = self.indices_check(indices, input_shape) if gather_idx is None: - log.warning('Node {} with op=GatherNd can\'t be normalized to op=Gather.'.format(gather_name)) + log.warning('Node {} with op=GatherND can\'t be normalized to op=Gather.'.format(gather_name)) return # 1. Add Reshape and connect new_shape = int64_array([-1] + list(input_shape[indices.shape[-1]:])) reshape = create_op_node_with_second_input(graph, Reshape, new_shape, - {'name': gather_name + '/Reshape_for_GatherNd/'}) + {'name': gather_name + '/Reshape_for_GatherND/'}) gather.in_port(0).get_connection().set_destination(reshape.in_port(0)) # 2. Change indices from Nd to 1d: diff --git a/model-optimizer/extensions/ops/GatherNd.py b/model-optimizer/extensions/ops/GatherNd.py deleted file mode 100644 index 219da66cfc43a8..00000000000000 --- a/model-optimizer/extensions/ops/GatherNd.py +++ /dev/null @@ -1,47 +0,0 @@ -""" - Copyright (C) 2018-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -import numpy as np - -from mo.graph.graph import Node, Graph -from mo.ops.op import Op - - -class GatherNd(Op): - op = 'GatherNd' - - def __init__(self, graph: Graph, attrs: dict): - mandatory_props = { - 'op': __class__.op, - 'infer': __class__.infer, - 'in_ports_count': 2, - 'out_ports_count': 1, - } - super().__init__(graph, mandatory_props, attrs) - - def supported_attrs(self): - return [] - - @staticmethod - def infer(node: Node): - input_node = node.in_node(0) - indices = node.in_node(1).value - - assert indices is not None - - output_shape = list(indices.shape[:-1]) + list(input_node.shape[indices.shape[-1]:]) - node.out_node().shape = np.array(output_shape, dtype=np.int64) - # TODO: implement constant path diff --git a/model-optimizer/extensions/ops/gathernd.py b/model-optimizer/extensions/ops/gathernd.py new file mode 100644 index 00000000000000..ff69ce7918d31a --- /dev/null +++ b/model-optimizer/extensions/ops/gathernd.py @@ -0,0 +1,102 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
class GatherND(Op):
    """``GatherND`` operation (emitted with ``version='opset5'``).

    The node is declared with two input ports (port 0: data, port 1: indices) and one
    output port. ``batch_dims`` is its only backend attribute and defaults to 0.
    """
    op = 'GatherND'

    def __init__(self, graph: Graph, attrs: dict):
        """Create the operation; ``attrs`` is handed to ``Op.__init__`` together with the defaults below."""
        mandatory_props = {
            'type': self.op,
            'op': self.op,
            'version': 'opset5',
            'infer': self.infer,
            'in_ports_count': 2,
            'out_ports_count': 1,
            'batch_dims': 0
        }
        super().__init__(graph, mandatory_props, attrs)

    def backend_attrs(self):
        """Return the attribute names reported as backend attributes."""
        return ['batch_dims']

    @staticmethod
    def infer(node: Node):
        """Infer the output shape of ``node`` and, when both inputs are constant, its value.

        The output shape is ``[prod(data_shape[:batch_dims])]`` (only when ``batch_dims > 0``)
        followed by ``indices_shape[batch_dims:-1]`` and
        ``data_shape[batch_dims + indices_shape[-1]:]``.

        :param node: the GatherND node; both input ports must be connected and the node
                     must carry a valid ``batch_dims`` attribute
        :raises AssertionError: if the number of connected inputs is not 2, ``batch_dims`` is
                     missing, negative or not smaller than both input ranks, an input is a
                     scalar, the batch dimensions of data and indices differ, or the index
                     tuple is longer than the non-batch rank of data
        """
        node_name = node.soft_get('name', node.id)
        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        assert len(connected_in_ports) == 2, \
            "Incorrect number of inputs for {} node".format(node_name)

        data_shape = node.in_port(0).data.get_shape()
        data_value = node.in_port(0).data.get_value()
        indices_shape = node.in_port(1).data.get_shape()
        indices_value = node.in_port(1).data.get_value()

        assert node.has_valid('batch_dims'), "Node {} must contain `batch_dims` attribute".format(node_name)
        batch_dims = node.batch_dims

        # check ranks of input tensors first: otherwise a scalar input is reported by the
        # batch_dims checks below and these dedicated messages can never be reached
        assert len(data_shape) > 0, "Data must not be a scalar"
        assert len(indices_shape) > 0, "Indices must not be a scalar"

        # a negative value would pass the upper-bound checks and silently yield a wrong shape
        assert batch_dims >= 0, "Number of batch dimensions must be non-negative"

        # check that a number of batch dimensions is less than both ranks of data and indices tensors
        assert batch_dims < len(data_shape), "Number of batch dimensions must be less than a rank of data"
        assert batch_dims < len(indices_shape), "Number of batch dimensions must be less than a rank of indices"

        # check that batch dimensions of data and indices are the same
        for batch_dim in range(batch_dims):
            assert data_shape[batch_dim] == indices_shape[batch_dim], \
                "The dimension {} for data and indices tensors must be the same".format(batch_dim)

        assert (batch_dims + indices_shape[-1]) <= len(data_shape), \
            "Length of a tuple with indices must not exceed a rank of data tensor excluding batch dimensions"

        # compute output shape: all batch dimensions are flattened into a single leading dimension
        number_batches = [np.prod(data_shape[:batch_dims]).tolist()] if batch_dims > 0 else list()
        slice_shape = list(data_shape[(batch_dims + indices_shape[-1]):])
        output_shape = number_batches + list(indices_shape[batch_dims:-1]) + slice_shape
        node.out_port(0).data.set_shape(int64_array(output_shape))

        # compute output value only if all input values are defined
        if data_value is None or indices_value is None:
            return

        output_value = np.zeros(output_shape, dtype=data_value.dtype)
        batch_shape = tuple(int64_array(indices_shape[:batch_dims]))
        inner_shape = tuple(int64_array(indices_shape[batch_dims:-1]))
        # with batch_dims == 0 `batch_shape` is empty and np.ndindex yields a single empty tuple,
        # so the same loop covers both the batched and the non-batched case
        for batch_indices in np.ndindex(batch_shape):
            if batch_dims > 0:
                # row-major position of this batch in the flattened leading output dimension
                output_prefix = (np.ravel_multi_index(batch_indices, batch_shape),)
            else:
                output_prefix = ()
            for output_index in np.ndindex(inner_shape):
                # the last axis of indices holds the coordinates to read from data
                indices_tuple = tuple(indices_value[batch_indices + output_index].T)
                output_value[output_prefix + output_index] = data_value[batch_indices + indices_tuple]
        node.out_port(0).data.set_value(output_value)
# Node attributes shared by every test graph: two producers (data, indices) feeding a GatherND node.
nodes_attributes = {'data': {'kind': 'op'},
                    'data_data': {'shape': None, 'value': None, 'kind': 'data'},
                    'indices': {'kind': 'op'},
                    'indices_data': {'shape': None, 'value': None, 'kind': 'data'},
                    'gathernd_node': {'op': 'GatherND', 'kind': 'op', 'batch_dims': 0},
                    'output': {'shape': None, 'value': None, 'kind': 'data'}}

# graph 1
edges = [('data', 'data_data', {'in': 0}),
         ('indices', 'indices_data', {'in': 1}),
         ('data_data', 'gathernd_node', {'in': 0}),
         ('indices_data', 'gathernd_node', {'in': 1}),
         ('gathernd_node', 'output', {'out': 0})]

# test data for partial infer: gather elements
inputs1 = {'data_data': {'shape': int64_array([10, 40]), 'value': None},
           'indices_data': {'shape': int64_array([3, 2]), 'value': None}}

# test data for partial infer: gather slices
inputs2 = {'data_data': {'shape': int64_array([10, 40, 30]), 'value': None},
           'indices_data': {'shape': int64_array([3, 2]), 'value': None}}

# test data for partial infer: gather slices and batch_dims=2
inputs3 = {'data_data': {'shape': int64_array([10, 40, 4, 9]), 'value': None},
           'indices_data': {'shape': int64_array([10, 40, 3, 5, 1]), 'value': None}}

# test data for constant folding: gather elements, batch_dims = 0
inputs4 = {'data_data': {'shape': int64_array([2, 2]), 'value': int64_array([[1, 2],
                                                                             [3, 4]])},
           'indices_data': {'shape': int64_array([2, 2]), 'value': int64_array([[0, 0],
                                                                                [1, 0]])}}
output4 = int64_array([1, 3])

# test data for constant folding: gather slices, batch_dims = 0
inputs5 = {'data_data': {'shape': int64_array([2, 3, 4]), 'value': int64_array([[[1, 2, 3, 4],
                                                                                 [5, 6, 7, 8],
                                                                                 [9, 10, 11, 12]],
                                                                                [[13, 14, 15, 16],
                                                                                 [17, 18, 19, 20],
                                                                                 [21, 22, 23, 24]]])},
           'indices_data': {'shape': int64_array([3, 2]), 'value': int64_array([[0, 1],
                                                                                [1, 0],
                                                                                [1, 2]])}}
output5 = int64_array([[5, 6, 7, 8],
                       [13, 14, 15, 16],
                       [21, 22, 23, 24]])

# test data for constant folding: gather slices, batch_dims = 1
inputs6 = {'data_data': {'shape': int64_array([2, 3, 4]), 'value': int64_array([[[1, 2, 3, 4],
                                                                                 [5, 6, 7, 8],
                                                                                 [9, 10, 11, 12]],
                                                                                [[13, 14, 15, 16],
                                                                                 [17, 18, 19, 20],
                                                                                 [21, 22, 23, 24]]])},
           'indices_data': {'shape': int64_array([2, 1]), 'value': int64_array([[1],
                                                                                [0]])}}
output6 = int64_array([[5, 6, 7, 8],
                       [13, 14, 15, 16]])

# test data for constant folding: gather slices with leading dimensions, batch_dims = 2
inputs7 = {'data_data': {'shape': int64_array([2, 3, 4]), 'value': int64_array([[[1, 2, 3, 4],
                                                                                 [5, 6, 7, 8],
                                                                                 [9, 10, 11, 12]],
                                                                                [[13, 14, 15, 16],
                                                                                 [17, 18, 19, 20],
                                                                                 [21, 22, 23, 24]]])},
           'indices_data': {'shape': int64_array([2, 3, 1, 1]), 'value': int64_array([[[[1]],
                                                                                       [[0]],
                                                                                       [[2]]],
                                                                                      [[[0]],
                                                                                       [[2]],
                                                                                       [[2]]]])}}
output7 = int64_array([[2], [5], [11], [13], [19], [23]])

# test data for constant folding: gather elements, batch_dims = 2
inputs8 = {'data_data': {'shape': int64_array([2, 3, 4, 2]),
                         'value': int64_array([[[[1, 2], [3, 4], [5, 6], [7, 8]],
                                                [[9, 10], [11, 12], [13, 14], [15, 16]],
                                                [[17, 18], [19, 20], [21, 22], [23, 24]]],
                                               [[[25, 26], [27, 28], [29, 30], [31, 32]],
                                                [[33, 34], [35, 36], [37, 38], [39, 40]],
                                                [[41, 42], [43, 44], [45, 46], [47, 48]]]])},
           'indices_data': {'shape': int64_array([2, 3, 3, 2]),
                            'value': int64_array([[[[1, 0], [3, 1], [2, 1]],
                                                   [[0, 1], [1, 1], [2, 0]],
                                                   [[3, 0], [3, 1], [2, 1]]],
                                                  [[[2, 0], [1, 1], [3, 1]],
                                                   [[1, 1], [2, 0], [2, 0]],
                                                   [[0, 0], [3, 1], [3, 1]]]])}}
output8 = int64_array([[3, 8, 6],
                       [10, 12, 13],
                       [23, 24, 22],
                       [29, 28, 32],
                       [36, 37, 37],
                       [41, 48, 48]])

# invalid test case with incorrect rank for indices
inputs_inv1 = {'data_data': {'shape': int64_array([10, 40]), 'value': None},
               'indices_data': {'shape': int64_array([5, 3, 4]), 'value': None}}

# invalid test case with unequal batch dimensions, batch_dims = 2
inputs_inv2 = {'data_data': {'shape': int64_array([10, 40, 20]), 'value': None},
               'indices_data': {'shape': int64_array([5, 3, 4]), 'value': None}}

# invalid test case with indices rank greater than a rank of data excluding batch dimensions, batch_dims = 2
inputs_inv3 = {'data_data': {'shape': int64_array([10, 40, 20, 10, 2]), 'value': None},
               'indices_data': {'shape': int64_array([10, 40, 4]), 'value': None}}


class TestScatterNDUpdate(unittest.TestCase):
    """Unit tests for ``GatherND.infer``: shape inference, constant folding and input validation.

    NOTE(review): the class name looks like a leftover from the ScatterNDUpdate tests; it is
    kept unchanged so existing test selectors keep working.
    """

    def setUp(self):
        # tests overwrite `batch_dims` in the shared module-level attributes, so reset it every time
        nodes_attributes['gathernd_node']['batch_dims'] = 0

    def _build_node(self, inputs, batch_dims):
        """Build the test graph for ``inputs`` and return ``(graph, gathernd_node)``."""
        nodes_attributes['gathernd_node']['batch_dims'] = batch_dims
        graph = build_graph(nodes_attributes, edges, inputs)
        return graph, Node(graph, 'gathernd_node')

    def _check_output(self, inputs, attr, reference, batch_dims=0):
        """Run ``GatherND.infer`` and compare attribute ``attr`` of the output node with ``reference``."""
        graph, gathernd_node = self._build_node(inputs, batch_dims)
        GatherND.infer(gathernd_node)
        result = graph.node['output'][attr]
        self.assertTrue(np.array_equal(reference, result),
                        'values do not match expected: {} and given: {}'.format(reference, result))

    def _check_invalid(self, inputs, batch_dims=0):
        """Check that ``GatherND.infer`` rejects ``inputs`` with an ``AssertionError``."""
        _, gathernd_node = self._build_node(inputs, batch_dims)
        self.assertRaises(AssertionError, GatherND.infer, gathernd_node)

    def test_partial_infer_gather_element(self):
        self._check_output(inputs1, 'shape', int64_array([3]))

    def test_partial_infer_gather_slice(self):
        self._check_output(inputs2, 'shape', int64_array([3, 30]))

    def test_partial_infer_gather_slice_batch_dims2(self):
        self._check_output(inputs3, 'shape', int64_array([400, 3, 5, 9]), batch_dims=2)

    def test_infer4(self):
        self._check_output(inputs4, 'value', output4)

    def test_infer5(self):
        self._check_output(inputs5, 'value', output5)

    def test_infer6(self):
        self._check_output(inputs6, 'value', output6, batch_dims=1)

    def test_infer7(self):
        self._check_output(inputs7, 'value', output7, batch_dims=2)

    def test_infer8(self):
        self._check_output(inputs8, 'value', output8, batch_dims=2)

    def test_infer_invalid1(self):
        self._check_invalid(inputs_inv1)

    def test_infer_invalid2(self):
        self._check_invalid(inputs_inv2, batch_dims=2)

    def test_infer_invalid3(self):
        self._check_invalid(inputs_inv3, batch_dims=2)