Skip to content

Commit

Permalink
[ONNX FE] Improved a method of operators registration (openvinotoolki…
Browse files Browse the repository at this point in the history
…t#15990)

* initial version of implementation

* styles applied

* fixed and registration

* add more unit tests

* fixed and in legacy opset

* review remarks

* refactor of version name range
  • Loading branch information
Mateusz Bencer authored and mryzhov committed Mar 17, 2023
1 parent 1044df5 commit 3461b52
Show file tree
Hide file tree
Showing 16 changed files with 313 additions and 21 deletions.
13 changes: 12 additions & 1 deletion src/frontends/onnx/frontend/src/op/abs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <memory>

#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/node.hpp"
#include "onnx_import/core/node.hpp"

Expand All @@ -15,11 +16,21 @@ namespace onnx_import {
namespace op {
namespace set_1 {
inline OutputVector abs(const Node& node) {
CHECK_VALID_NODE(node,
!node.has_attribute("consumed_inputs"),
"consumed_inputs legacy attribute of Abs op is not supported");
return {std::make_shared<default_opset::Abs>(node.get_ng_inputs().at(0))};
}

} // namespace set_1

namespace set_6 {
using set_1::abs;
} // namespace set_6

namespace set_13 {
using set_6::abs;
} // namespace set_13

} // namespace op

} // namespace onnx_import
Expand Down
4 changes: 2 additions & 2 deletions src/frontends/onnx/frontend/src/op/acos.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
namespace set_7 {
inline OutputVector acos(const Node& node) {
return {std::make_shared<default_opset::Acos>(node.get_ng_inputs().at(0))};
}
} // namespace set_1
} // namespace set_7

} // namespace op

Expand Down
4 changes: 2 additions & 2 deletions src/frontends/onnx/frontend/src/op/acosh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
namespace set_9 {
inline OutputVector acosh(const Node& node) {
return {std::make_shared<default_opset::Acosh>(node.get_ng_inputs().at(0))};
}
} // namespace set_1
} // namespace set_9

} // namespace op

Expand Down
12 changes: 10 additions & 2 deletions src/frontends/onnx/frontend/src/op/add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "op/add.hpp"

#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/shape.hpp"
#include "utils/common.hpp"
Expand All @@ -14,16 +15,23 @@ namespace onnx_import {
namespace op {
namespace set_1 {
OutputVector add(const Node& node) {
CHECK_VALID_NODE(node,
!node.has_attribute("consumed_inputs"),
"consumed_inputs legacy attribute of Add op is not supported");
return common::handle_opset6_binary_op<default_opset::Add>(node);
}

} // namespace set_1

namespace set_6 {
OutputVector add(const Node& node) {
return common::handle_opset6_binary_op<default_opset::Add>(node);
}
} // namespace set_6

namespace set_7 {
OutputVector add(const Node& node) {
return {std::make_shared<default_opset::Add>(node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
}

} // namespace set_7

} // namespace op
Expand Down
13 changes: 13 additions & 0 deletions src/frontends/onnx/frontend/src/op/add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,24 @@ OutputVector add(const Node& node);

} // namespace set_1

namespace set_6 {
OutputVector add(const Node& node);

} // namespace set_6

namespace set_7 {
OutputVector add(const Node& node);

} // namespace set_7

namespace set_13 {
using set_7::add;
} // namespace set_13

namespace set_14 {
using set_13::add;
} // namespace set_14

} // namespace op

} // namespace onnx_import
Expand Down
12 changes: 8 additions & 4 deletions src/frontends/onnx/frontend/src/op/and.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,23 @@
#include "ngraph/node.hpp"
#include "ngraph/op/and.hpp"
#include "onnx_import/core/node.hpp"
#include "utils/common.hpp"

namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
inline OutputVector logical_and(const Node& node) {
return {std::make_shared<default_opset::LogicalAnd>(node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
return common::handle_opset6_binary_op<default_opset::LogicalAnd>(node);
}

} // namespace set_1

} // namespace op
namespace set_7 {
inline OutputVector logical_and(const Node& node) {
return {std::make_shared<default_opset::LogicalAnd>(node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
}
} // namespace set_7

} // namespace op
} // namespace onnx_import

} // namespace ngraph
40 changes: 32 additions & 8 deletions src/frontends/onnx/frontend/src/ops_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@
#include "op/where.hpp"
#include "op/xor.hpp"

using namespace ov::frontend::onnx;

namespace ngraph {
namespace onnx_import {
namespace {
Expand All @@ -195,6 +197,22 @@ typename Container::const_iterator find(int64_t version, const Container& map) {
}
} // namespace

void OperatorsBridge::register_operator_in_custom_domain(std::string name,
VersionRange range,
Operator fn,
std::string domain,
std::string warning_mes) {
for (int version = range.m_since; version <= range.m_until; ++version) {
register_operator(name, version, domain, fn);
}
NGRAPH_WARN << "Operator: " << name << " since version: " << range.m_since << " until version: " << range.m_until
<< " registered with warning: " << warning_mes;
}

void OperatorsBridge::register_operator(std::string name, VersionRange range, Operator fn, std::string warning_mes) {
register_operator_in_custom_domain(name, range, std::move(fn), "", warning_mes);
}

void OperatorsBridge::register_operator(const std::string& name,
int64_t version,
const std::string& domain,
Expand Down Expand Up @@ -243,9 +261,9 @@ OperatorSet OperatorsBridge::get_operator_set(const std::string& domain, int64_t
NGRAPH_DEBUG << "Domain '" << domain << "' not recognized by nGraph";
return result;
}
if (domain == "" && version > OperatorsBridge::LATEST_SUPPORTED_ONNX_OPSET_VERSION) {
if (domain == "" && version > LATEST_SUPPORTED_ONNX_OPSET_VERSION) {
NGRAPH_WARN << "Currently ONNX operator set version: " << version
<< " is unsupported. Falling back to: " << OperatorsBridge::LATEST_SUPPORTED_ONNX_OPSET_VERSION;
<< " is unsupported. Falling back to: " << LATEST_SUPPORTED_ONNX_OPSET_VERSION;
}
for (const auto& op : dm->second) {
const auto& it = find(version, op.second);
Expand Down Expand Up @@ -293,12 +311,18 @@ static const char* const PYTORCH_ATEN_DOMAIN = "org.pytorch.aten";
m_map[domain_][name_].emplace(ver_, std::bind(op::set_##ver_::fn_, std::placeholders::_1));

OperatorsBridge::OperatorsBridge() {
REGISTER_OPERATOR("Abs", 1, abs);
REGISTER_OPERATOR("Acos", 1, acos);
REGISTER_OPERATOR("Acosh", 1, acosh);
REGISTER_OPERATOR("Add", 1, add);
REGISTER_OPERATOR("Add", 7, add);
REGISTER_OPERATOR("And", 1, logical_and);
register_operator("Abs", VersionRange{1, 5}, op::set_1::abs, "Legacy consumed_inputs is not supported");
register_operator("Abs", VersionRange::since(6), op::set_6::abs);
register_operator("Acos", VersionRange::single_version_for_all_opsets(), op::set_7::acos);
register_operator("Acosh", VersionRange::single_version_for_all_opsets(), op::set_9::acosh);
register_operator("Add", VersionRange{1, 5}, op::set_1::add, "Legacy consumed_inputs is not supported");
register_operator("Add", VersionRange::in(6), op::set_6::add);
register_operator("Add", VersionRange{7, 12}, op::set_7::add);
register_operator("Add", VersionRange::in(13), op::set_13::add);
register_operator("Add", VersionRange::since(14), op::set_14::add);
register_operator("And", VersionRange{1, 6}, op::set_1::logical_and);
register_operator("And", VersionRange::since(6), op::set_7::logical_and);
// 101468 - Use the VersionRange-based approach for all operators
REGISTER_OPERATOR("ArgMin", 1, argmin);
REGISTER_OPERATOR("ArgMin", 12, argmin);
REGISTER_OPERATOR("ArgMax", 1, argmax);
Expand Down
12 changes: 10 additions & 2 deletions src/frontends/onnx/frontend/src/ops_bridge.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "ngraph/except.hpp"
#include "onnx_import/core/operator_set.hpp"
#include "version_range.hpp"

namespace ngraph {
namespace onnx_import {
Expand All @@ -35,8 +36,6 @@ struct UnsupportedVersion : ngraph_error {

class OperatorsBridge {
public:
static constexpr const int LATEST_SUPPORTED_ONNX_OPSET_VERSION = ONNX_OPSET_VERSION;

OperatorsBridge();

OperatorsBridge(const OperatorsBridge&) = default;
Expand Down Expand Up @@ -77,6 +76,15 @@ class OperatorsBridge {
void overwrite_operator(const std::string& name, const std::string& domain, Operator fn);

private:
void register_operator_in_custom_domain(std::string name,
ov::frontend::onnx::VersionRange range,
Operator fn,
std::string domain,
std::string warning_mes = "");
void register_operator(std::string name,
ov::frontend::onnx::VersionRange range,
Operator fn,
std::string warning_mes = "");
// Registered operators structure
// {
// domain_1: {
Expand Down
1 change: 1 addition & 0 deletions src/frontends/onnx/frontend/src/utils/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ template OutputVector handle_opset6_binary_op<default_opset::Add>(const Node& no
template OutputVector handle_opset6_binary_op<default_opset::Divide>(const Node& node);
template OutputVector handle_opset6_binary_op<default_opset::Multiply>(const Node& node);
template OutputVector handle_opset6_binary_op<default_opset::Subtract>(const Node& node);
template OutputVector handle_opset6_binary_op<default_opset::LogicalAnd>(const Node& node);

const std::string FAILSAFE_NODE = "ONNX_FAILSAFE_NODE";

Expand Down
32 changes: 32 additions & 0 deletions src/frontends/onnx/frontend/src/version_range.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

namespace ov {
namespace frontend {
namespace onnx {

constexpr int LATEST_SUPPORTED_ONNX_OPSET_VERSION = ONNX_OPSET_VERSION;
struct VersionRange {
constexpr VersionRange(int since_version, int until_version) : m_since(since_version), m_until(until_version) {}
static constexpr VersionRange since(int since_version) {
return VersionRange{since_version, LATEST_SUPPORTED_ONNX_OPSET_VERSION};
}
static constexpr VersionRange until(int until_version) {
return VersionRange{1, until_version};
}
static constexpr VersionRange in(int version) {
return VersionRange{version, version};
}
static constexpr VersionRange single_version_for_all_opsets() {
return VersionRange{1, LATEST_SUPPORTED_ONNX_OPSET_VERSION};
}
// -1 means that that a left/right boundary of the range was not specified
const int m_since = -1, m_until = -1;
};

} // namespace onnx
} // namespace frontend
} // namespace ov
4 changes: 4 additions & 0 deletions src/frontends/onnx/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ set(SRC
telemetry.cpp
lib_close.cpp
model_support_tests.cpp
onnx_ops_registration.cpp
onnx_reader_external_data.cpp
skip_tests_config.cpp)

Expand Down Expand Up @@ -134,6 +135,9 @@ target_compile_definitions(ov_onnx_frontend_tests
SHARED_LIB_PREFIX="${CMAKE_SHARED_LIBRARY_PREFIX}"
SHARED_LIB_SUFFIX="${IE_BUILD_POSTFIX}${CMAKE_SHARED_LIBRARY_SUFFIX}")

set(ONNX_OPSET_VERSION 17 CACHE INTERNAL "Supported version of ONNX operator set")
target_compile_definitions(ov_onnx_frontend_tests PRIVATE ONNX_OPSET_VERSION=${ONNX_OPSET_VERSION})

if(ONNX_TESTS_DEPENDENCIES)
add_dependencies(ov_onnx_frontend_tests ${ONNX_TESTS_DEPENDENCIES})
endif()
Expand Down
45 changes: 45 additions & 0 deletions src/frontends/onnx/tests/models/abs.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
ir_version: 3
producer_name: "ONNX FE"
graph {
node {
input: "x"
output: "y"
op_type: "Abs"
}
name: "abs_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 4
}
Loading

0 comments on commit 3461b52

Please sign in to comment.