Skip to content

Commit

Permalink
nGraph 'shell' implementation for GatherElements-6 and MO 'shell' imp…
Browse files Browse the repository at this point in the history
…lementation (openvinotoolkit#3467)

* Initial support of GatherElements in MO and nGraph

* apply_style

* added lost extractor for GatherElements

* Corrected GatherElements::validate_and_infer_types

* updated package_BOM.txt

* Type_t added

* started to implement ngraph shape_type_infer unit-tests

* finally implemented all ngraph shape_inference unit-tests

* updated Supported_Frameworks_Layers.md

* added correct handling of dynamic shapes in nGraph, added unit-tests for dynamic cases, fixed dump typos in MO, replaced axis type from int -> int64_t

* implemented shape infer for dynamic shapes with intervals

* finalized MO implementation

* applied comment from review

* style-apply

* spec correction

* removed conflict

* fixed typos

* removed obsolete comments form type_prop

* significant corrections in validate_and_infer_types

* style-apply

* data_rank check for axis
  • Loading branch information
pavel-esir authored and mryzhov committed Jan 14, 2021
1 parent c3509b6 commit 86bf405
Show file tree
Hide file tree
Showing 12 changed files with 772 additions and 38 deletions.
1 change: 1 addition & 0 deletions docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ Standard ONNX\* operators:
| Floor | No |
| GRU | No |
| Gather | No |
| GatherElements | Doesn't work with negative indices |
| GatherND | No |
| GatherTree | No |
| Gemm | No |
Expand Down
76 changes: 38 additions & 38 deletions docs/ops/movement/GatherElements_6.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,53 +20,53 @@ For instance, in the 3D case (`r = 3`), the output is determined by the followin
```
Example 1 with concrete values:
```
data = [
[1, 2],
[3, 4],
]
indices = [
[0, 1],
[0, 0],
]
axis = 0
output = [
[1, 4],
[1, 2],
]
data = [
[1, 2],
[3, 4],
]
indices = [
[0, 1],
[0, 0],
]
axis = 0
output = [
[1, 4],
[1, 2],
]
```
Example 2 with `axis` = 1 and `indices` having greater (than `data`) shape:
```
data = [
[1, 7],
[4, 3],
]
indices = [
[1, 1, 0],
[1, 0, 1],
]
axis = 1
output = [
[7, 7, 1],
[3, 4, 3],
]
[1, 7],
[4, 3],
]
indices = [
[1, 1, 0],
[1, 0, 1],
]
axis = 1
output = [
[7, 7, 1],
[3, 4, 3],
]
```

Example 3 `indices` has lesser (than `data`) shape:
```
data = [
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
]
indices = [
[1, 0, 1],
[1, 2, 0],
]
axis = 0
output = [
[4, 2, 6],
[4, 8, 3],
]
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
]
indices = [
[1, 0, 1],
[1, 2, 0],
]
axis = 0
output = [
[4, 2, 6],
[4, 8, 3],
]
```

**Attributes**:
Expand Down
2 changes: 2 additions & 0 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ extensions/front/onnx/flatten_ext.py
extensions/front/onnx/flattenONNX_to_reshape.py
extensions/front/onnx/fused_bn_ext.py
extensions/front/onnx/gather_ext.py
extensions/front/onnx/gatherelements_ext.py
extensions/front/onnx/gathernd_ext.py
extensions/front/onnx/gemm_ext.py
extensions/front/onnx/group_norm_ext.py
Expand Down Expand Up @@ -635,6 +636,7 @@ extensions/ops/ExtractImagePatches.py
extensions/ops/fake_output.py
extensions/ops/fakequantize.py
extensions/ops/gather.py
extensions/ops/gatherelements.py
extensions/ops/gathernd.py
extensions/ops/GatherTree.py
extensions/ops/gelu.py
Expand Down
32 changes: 32 additions & 0 deletions model-optimizer/extensions/front/onnx/gatherelements_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
Copyright (C) 2017-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from extensions.ops.gatherelements import GatherElements
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr


class GatherElementsFrontExtractor(FrontExtractorOp):
op = 'GatherElements'
enabled = True

@classmethod
def extract(cls, node):
attrs = {
'axis': onnx_attr(node, 'axis', 'i', default=0)
}
GatherElements.update_node_stat(node, attrs)
return cls.enabled
75 changes: 75 additions & 0 deletions model-optimizer/extensions/ops/gatherelements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Copyright (C) 2017-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np

from mo.graph.graph import Node, Graph
from mo.ops.op import Op, PermuteAttrs


class GatherElements(Op):
op = 'GatherElements'

def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'op': self.op,
'type': self.op,
'version': 'opset6',
'infer': self.infer,
'in_ports_count': 2,
'out_ports_count': 1,
'axis': 0,
}, attrs)

def backend_attrs(self):
return ['axis']

@staticmethod
def infer(node: Node):
data_shape = node.in_port(0).data.get_shape()
indices_shape = node.in_port(1).data.get_shape()
axis = node.axis
data_rank = len(data_shape)

assert data_rank >= 1, 'data_rank must be >= 1'
assert data_rank == len(indices_shape), 'data and indices inputs for node {} must be of the ' \
'same rank. Instead got {} and {}'.\
format(node.name, data_rank, len(indices_shape))
assert -data_rank <= axis < data_rank, 'axis for node {0} must be within interval ' \
'[-{1}}, {1} - 1]. Instead got: axis={2}'.\
format(node.name, data_rank, axis)
if axis < 0:
axis += data_rank

for idx, (data_sz, ind_sz) in enumerate(zip(data_shape, indices_shape)):
if idx != axis and data_sz != ind_sz:
raise ValueError('Sizes along axis {} for node {} do not match. data and indices must have '
'equal size along all axes except for axis {}'.format(idx, node.name, axis))

data = node.in_port(0).data.get_value()
indices = node.in_port(1).data.get_value()

if data is not None and indices is not None:
out_value = np.empty(indices_shape, dtype=data.dtype)
for idx in np.ndindex(*indices_shape):
data_idx = list(idx)
data_idx[node.axis] = indices[idx]
out_value[idx] = data[tuple(data_idx)]
node.out_port(0).data.set_value(out_value)
else:
node.out_port(0).data.set_shape(indices_shape)

PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
115 changes: 115 additions & 0 deletions model-optimizer/extensions/ops/gatherelements_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
Copyright (C) 2017-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

import numpy as np
from generator import generator, generate

from extensions.ops.gatherelements import GatherElements
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node
from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, connect, \
valued_const_with_data


@generator
class GatherElementsInferTest(unittest.TestCase):
@generate(*[
([[1, 2],
[3, 4]],
[[0, 1],
[0, 0]],
0, # axis
[[1, 4], # ref_res
[1, 2]]),
([[1, 2],
[3, 4]],
[[0, 1],
[0, 0]],
1, # axis
[[1, 2], # ref_res
[3, 3]]),
([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]],
[[1, 2, 0],
[2, 0, 0]],
0, # axis
[[4, 8, 3], # ref_res
[7, 2, 3]]),
([[1, 2],
[3, 4]],
[[0, 1],
[0, 0]],
-1, # axis
[[1, 2], # ref_res
[3, 3]]),
([ # 3D case
[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]],
[[9, 10],
[11, 12]]
],
[
[[1, 0],
[0, 1]],
[[1, 1],
[1, 0]],
[[0, 0],
[1, 1]]
],
-1, # axis
[
[[2, 1],
[3, 4]],
[[6, 6],
[8, 7]],
[[9, 9],
[12, 12]]
]),
])

def test_gatherelements_value_infer(self, data, indices, axis, ref_res):
nodes = {
**valued_const_with_data('data', int64_array(data)),
**valued_const_with_data('indices', int64_array(indices)),
**regular_op_with_empty_data('gather_elements', {'op': 'GatherElements', 'axis': axis}),
**result()
}

graph = build_graph(nodes_attrs=nodes, edges=[
*connect('data', '0:gather_elements'),
*connect('indices', '1:gather_elements'),
*connect('gather_elements', 'output')
], nodes_with_edges_only=True)
graph.stage = 'middle'

gather_el_node = Node(graph, 'gather_elements')
GatherElements.infer(gather_el_node)

res_output_shape = gather_el_node.out_node().shape
self.assertTrue(np.array_equal(int64_array(ref_res).shape, res_output_shape))

res_output_value = gather_el_node.out_node().value
if res_output_value is not None:
self.assertTrue(np.array_equal(int64_array(ref_res), res_output_value))
55 changes: 55 additions & 0 deletions ngraph/core/include/ngraph/op/gather_elements.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/op/op.hpp"

namespace ngraph
{
namespace op
{
namespace v6
{
/// \brief GatherElements operation
///
class NGRAPH_API GatherElements : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
GatherElements() = default;

/// \brief Constructs a GatherElements operation.
///
/// \param data Node producing data that are gathered
/// \param indices Node producing indices by which the operation gathers elements
/// \param axis specifies axis along which indices are specified
GatherElements(const Output<Node>& data,
const Output<Node>& indices,
const int64_t axis);

void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;

int64_t get_axis() const { return m_axis; }
private:
int64_t m_axis;
};
}
}
}
Loading

0 comments on commit 86bf405

Please sign in to comment.