Skip to content

Commit

Permalink
[ONNX] Gather 8 (#7185)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-jankowski authored Aug 26, 2021
1 parent 0bc991d commit e1226cc
Show file tree
Hide file tree
Showing 9 changed files with 431 additions and 7 deletions.
11 changes: 4 additions & 7 deletions ngraph/frontend/onnx/frontend/src/op/gather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

#include <memory>

#include "default_opset.hpp"
#include "ngraph/node.hpp"
#include "ngraph/opsets/opset8.hpp"
#include "ngraph/validation_util.hpp"
#include "onnx_import/core/node.hpp"

Expand All @@ -21,15 +21,12 @@ inline OutputVector gather(const Node& node) {
auto indices = ng_inputs.at(1);
auto axis = node.get_attribute_value<int64_t>("axis", 0);

return {std::make_shared<default_opset::Gather>(data,
indices,
default_opset::Constant::create(element::i64, Shape{}, {axis}))};
return {std::make_shared<ngraph::opset8::Gather>(data,
indices,
default_opset::Constant::create(element::i64, Shape{}, {axis}))};
}

} // namespace set_1

} // namespace op

} // namespace onnx_import

} // namespace ngraph
65 changes: 65 additions & 0 deletions ngraph/test/models/onnx/gather_float_1D.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
ir_version: 7
graph {
node {
input: "data"
input: "indices"
output: "output"
name: "gather"
op_type: "Gather"
attribute {
name: "axis"
i: 0
type: INT
}
}
name: "test-gather8"
input {
name: "data"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
}
}
}
}
input {
name: "indices"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
domain: ""
version: 13
}
71 changes: 71 additions & 0 deletions ngraph/test/models/onnx/gather_float_2D_axis_1.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
ir_version: 7
graph {
node {
input: "data"
input: "indices"
output: "output"
name: "gather"
op_type: "Gather"
attribute {
name: "axis"
i: 1
type: INT
}
}
name: "test-gather8"
input {
name: "data"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "indices"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
domain: ""
version: 13
}
77 changes: 77 additions & 0 deletions ngraph/test/models/onnx/gather_int32_3D_axis_1.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
ir_version: 7
graph {
node {
input: "data"
input: "indices"
output: "output"
name: "gather"
op_type: "Gather"
attribute {
name: "axis"
i: 1
type: INT
}
}
name: "test-gather8"
input {
name: "data"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "indices"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 4
}
dim {
dim_value: 1
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 2
}
dim {
dim_value: 4
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
domain: ""
version: 13
}
77 changes: 77 additions & 0 deletions ngraph/test/models/onnx/gather_int8_3D_axis_neg_1.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
ir_version: 7
graph {
node {
input: "data"
input: "indices"
output: "output"
name: "gather"
op_type: "Gather"
attribute {
name: "axis"
i: -1
type: INT
}
}
name: "test-gather8"
input {
name: "data"
type {
tensor_type {
elem_type: 3
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "indices"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 4
}
dim {
dim_value: 1
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 3
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 4
}
dim {
dim_value: 1
}
}
}
}
}
}
opset_import {
domain: ""
version: 13
}
Loading

0 comments on commit e1226cc

Please sign in to comment.