Skip to content

Commit

Permalink
Extend MO for operation GatherND (openvinotoolkit#2540)
Browse files Browse the repository at this point in the history
* Extend MO for operation GatherND

* Update documentation

* Rename GatherNd.py to gathernd.py

Signed-off-by: Roman Kazantsev <[email protected]>
  • Loading branch information
rkazants authored and mryzhov committed Dec 15, 2020
1 parent b06e5d0 commit 5235ce1
Show file tree
Hide file tree
Showing 8 changed files with 433 additions and 56 deletions.
3 changes: 2 additions & 1 deletion docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ Standard TensorFlow\* operations:
| FloorDiv | No |
| FusedBatchNorm | No |
| Gather | No |
| GatherNd | Supported if it can be replaced with Gather |
| GatherNd | No |
| GatherV2 | No |
| Greater | No |
| GreaterEqual | No |
Expand Down Expand Up @@ -337,6 +337,7 @@ Standard ONNX\* operators:
| Floor | No |
| GRU | No |
| Gather | No |
| GatherND | No |
| GatherTree | No |
| Gemm | No |
| GlobalAveragePool | No |
Expand Down
4 changes: 3 additions & 1 deletion model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ extensions/front/onnx/expand_ext.py
extensions/front/onnx/flatten_ext.py
extensions/front/onnx/flattenONNX_to_reshape.py
extensions/front/onnx/gather_ext.py
extensions/front/onnx/gathernd_ext.py
extensions/front/onnx/gemm_ext.py
extensions/front/onnx/group_norm_ext.py
extensions/front/onnx/gru_ext.py
Expand Down Expand Up @@ -382,6 +383,7 @@ extensions/front/tf/FlattenToReshape.py
extensions/front/tf/floor_div_decomposition.py
extensions/front/tf/floor_ext.py
extensions/front/tf/gather_ext.py
extensions/front/tf/gathernd_ext.py
extensions/front/tf/GatherTree_ext.py
extensions/front/tf/GNMT_DynamicSequenceLengths.py
extensions/front/tf/identity_ext.py
Expand Down Expand Up @@ -617,7 +619,7 @@ extensions/ops/ExtractImagePatches.py
extensions/ops/fake_output.py
extensions/ops/fakequantize.py
extensions/ops/gather.py
extensions/ops/GatherNd.py
extensions/ops/gathernd.py
extensions/ops/GatherTree.py
extensions/ops/gelu.py
extensions/ops/grn.py
Expand Down
32 changes: 32 additions & 0 deletions model-optimizer/extensions/front/onnx/gathernd_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from extensions.ops.gathernd import GatherND
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr


class GatherNDFrontExtractor(FrontExtractorOp):
op = 'GatherND'
enabled = True

@classmethod
def extract(cls, node):
attrs = {
'batch_dims': onnx_attr(node, 'batch_dims', 'i', default=0)
}
GatherND.update_node_stat(node, attrs)
return cls.enabled
30 changes: 30 additions & 0 deletions model-optimizer/extensions/front/tf/gathernd_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.gathernd import GatherND
from mo.front.extractor import FrontExtractorOp


class GatherNDFrontExtractor(FrontExtractorOp):
op = 'GatherNd'
enabled = True

@classmethod
def extract(cls, node):
attrs = {
'batch_dims': 0,
}
GatherND.update_node_stat(node, attrs)
return cls.enabled
17 changes: 10 additions & 7 deletions model-optimizer/extensions/middle/GatherNdNormalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@
from mo.ops.reshape import Reshape


class GatherNdNormalize(MiddleReplacementPattern):
class GatherNDNormalize(MiddleReplacementPattern):
"""
Hot fix for new speech-to-text model enabling while GatherND is not implemented in IE.
We can replace GatherNd to Reshape + Gather in case when GatherNd indices have just one
We can replace GatherND to Reshape + Gather in case when GatherND indices have just one
meaningful dimension.
TODO: Investigate whether we must replace GatherND with Reshape + Gather always (due to performance benefits)
for this particular case or only if the plugin does not support GatherND.
And the best place for the transformation is nGraph so we need to move it.
"""
enabled = True
force_clean_up = True
Expand All @@ -44,7 +47,7 @@ def run_after(self):

def pattern(self):
return dict(
nodes=[('GatherNd', dict(kind='op', op='GatherNd'))],
nodes=[('GatherND', dict(kind='op', op='GatherND', batch_dims=0))],
edges=[]
)

Expand All @@ -67,24 +70,24 @@ def indices_check(indices: np.array, input_shape: tuple):
return non_zero

def replace_pattern(self, graph: Graph, match: dict):
gather = match['GatherNd']
gather = match['GatherND']
gather_name = gather.soft_get('name', gather.id)
input_shape = gather.in_node(0).shape
indices = gather.in_node(1).value
if indices is None:
# We can't do such special pass without indices value
return

# 0. All needed checks that we can replace GatherNd by Gather
# 0. All needed checks that we can replace GatherND by Gather
gather_idx = self.indices_check(indices, input_shape)
if gather_idx is None:
log.warning('Node {} with op=GatherNd can\'t be normalized to op=Gather.'.format(gather_name))
log.warning('Node {} with op=GatherND can\'t be normalized to op=Gather.'.format(gather_name))
return

# 1. Add Reshape and connect
new_shape = int64_array([-1] + list(input_shape[indices.shape[-1]:]))
reshape = create_op_node_with_second_input(graph, Reshape, new_shape,
{'name': gather_name + '/Reshape_for_GatherNd/'})
{'name': gather_name + '/Reshape_for_GatherND/'})
gather.in_port(0).get_connection().set_destination(reshape.in_port(0))

# 2. Change indices from Nd to 1d:
Expand Down
47 changes: 0 additions & 47 deletions model-optimizer/extensions/ops/GatherNd.py

This file was deleted.

102 changes: 102 additions & 0 deletions model-optimizer/extensions/ops/gathernd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.ops.op import Op


class GatherND(Op):
op = 'GatherND'

def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'type': self.op,
'op': self.op,
'version': 'opset5',
'infer': self.infer,
'in_ports_count': 2,
'out_ports_count': 1,
'batch_dims': 0
}
super().__init__(graph, mandatory_props, attrs)

def backend_attrs(self):
return ['batch_dims']

@staticmethod
def infer(node: Node):
node_name = node.soft_get('name', node.id)
connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
assert len(connected_in_ports) == 2, \
"Incorrect number of inputs for {} node".format(node_name)

data_shape = node.in_port(0).data.get_shape()
data_value = node.in_port(0).data.get_value()
indices_shape = node.in_port(1).data.get_shape()
indices_value = node.in_port(1).data.get_value()

assert node.has_valid('batch_dims'), "Node {} must contain `batch_dims` attribute".format(node_name)
batch_dims = node.batch_dims

# check that a number of batch dimensions is less than both ranks of data and indices tensors
assert batch_dims < len(data_shape), "Number of batch dimensions must be less than a rank of data"
assert batch_dims < len(indices_shape), "Number of batch dimensions must be less than a rank of indices"

# check that batch dimensions of data and indices are the same
for batch_dim in range(batch_dims):
assert data_shape[batch_dim] == indices_shape[batch_dim], \
"The dimension {} for data and indices tensors must be the same".format(batch_dim)

# check ranks of input tensors
assert len(data_shape) > 0, "Data must not be a scalar"
assert len(indices_shape) > 0, "Indices must not be a scalar"
assert (batch_dims + indices_shape[-1]) <= len(data_shape), \
"Length of a tuple with indices must not exceed a rank of data tensor excluding batch dimensions"

# compute output shape
number_batches = [np.prod(data_shape[:batch_dims]).tolist()] if batch_dims > 0 else list()
slice_shape = list(data_shape[(batch_dims + indices_shape[-1]):])
output_shape = number_batches + list(indices_shape[batch_dims:-1]) + slice_shape
node.out_port(0).data.set_shape(int64_array(output_shape))

# compute output value if all input values are defined
if data_value is not None and indices_value is not None:
output_value = np.zeros(output_shape, dtype=data_value.dtype)
if batch_dims == 0:
output_indices_range = int64_array(indices_shape[:-1])
for output_index in np.ndindex(tuple(output_indices_range)):
indices_tuple = indices_value[output_index]
output_value[output_index] = data_value[tuple(indices_tuple.T)]
else:
batch_dims_range = int64_array(indices_shape[:batch_dims])
for batch_indices in np.ndindex(tuple(batch_dims_range)):
# compute batch index in output tensor
batch_ind = 0
num_elements = 1
for ind in reversed(range(len(batch_dims_range))):
batch_ind += batch_indices[ind] * num_elements
num_elements *= batch_dims_range[ind]
output_indices_range = int64_array(indices_shape[batch_dims:-1])
for output_index in np.ndindex(tuple(output_indices_range)):
tmp_ind = batch_indices + output_index
indices_tuple = tuple(indices_value[tmp_ind].T)
full_input_ind = batch_indices + indices_tuple
full_output_ind = tuple(np.array([batch_ind]).T) + output_index
output_value[full_output_ind] = data_value[full_input_ind]
node.out_port(0).data.set_value(output_value)
Loading

0 comments on commit 5235ce1

Please sign in to comment.