Skip to content

Commit

Permalink
[ONNX] Frontend refactoring (openvinotoolkit#21931)
Browse files Browse the repository at this point in the history
* Updated graph.hpp/cpp
* Added check of ONNX Runtime integration for arm64
* Fixed code style in null_node.hpp
Co-authored-by: Roman Kazantsev <[email protected]>
* Update null_node.cpp
* Refactored exceptions in attribute.hpp/cpp
---------
Co-authored-by: Roman Kazantsev <[email protected]>
  • Loading branch information
gkrivor authored Jan 4, 2024
1 parent f0cffc4 commit 72e774f
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 125 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/linux_arm64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ jobs:

ONNX_Runtime:
name: ONNX Runtime Integration
if: fromJSON(needs.smart_ci.outputs.affected_components).ONNX_RT
if: fromJSON(needs.smart_ci.outputs.affected_components).ONNX_RT ||
fromJSON(needs.smart_ci.outputs.affected_components).ONNX_FE
needs: [ Build, Smart_CI ]
uses: ./.github/workflows/job_onnx_runtime.yml
with:
Expand Down
4 changes: 2 additions & 2 deletions src/frontends/onnx/frontend/include/onnx_import/core/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ namespace ngraph {
namespace onnx_import {
namespace error {
namespace node {
struct UnknownAttribute : ngraph_error {
struct UnknownAttribute : ov::Exception {
OPENVINO_SUPPRESS_DEPRECATED_START
explicit UnknownAttribute(const std::string& node, const std::string& name)
: ngraph_error{"Node (" + node + "): unknown attribute \'" + name + "\'"} {}
: ov::Exception{"Node (" + node + "): unknown attribute \'" + name + "\'"} {}
OPENVINO_SUPPRESS_DEPRECATED_END
};

Expand Down
18 changes: 13 additions & 5 deletions src/frontends/onnx/frontend/include/onnx_import/core/null_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,24 @@

#include <memory>

#include "ngraph/deprecated.hpp"
#include "onnx_import/onnx_importer_visibility.hpp"
#include "openvino/core/deprecated.hpp"
#include "openvino/op/op.hpp"

namespace ngraph {
namespace ov {
namespace op {
NGRAPH_API_DEPRECATED ONNX_IMPORTER_API bool is_null(const ngraph::Node* node);
NGRAPH_API_DEPRECATED ONNX_IMPORTER_API bool is_null(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API_DEPRECATED ONNX_IMPORTER_API bool is_null(const Output<ngraph::Node>& output);
namespace util {
NGRAPH_API_DEPRECATED ONNX_IMPORTER_API bool is_null(const ov::Node* node);
NGRAPH_API_DEPRECATED ONNX_IMPORTER_API bool is_null(const std::shared_ptr<ov::Node>& node);
NGRAPH_API_DEPRECATED ONNX_IMPORTER_API bool is_null(const Output<ov::Node>& output);
} // namespace util
} // namespace op
} // namespace ov
namespace ngraph {
namespace op {
using namespace ov::op::util;
}

namespace onnx_import {
/// \brief Represents a missing optional input or output of an ONNX node
///
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/onnx/frontend/src/core/attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace ngraph {
namespace onnx_import {
Subgraph Attribute::get_subgraph(Graph* parent_graph) const {
if (m_attribute_proto->type() != ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH) {
throw error::attribute::InvalidData{m_attribute_proto->type()};
ONNX_INVALID_ATTR(m_attribute_proto->type(), "GRAPH");
}

auto model_proto = std::make_shared<ONNX_NAMESPACE::ModelProto>();
Expand Down
59 changes: 21 additions & 38 deletions src/frontends/onnx/frontend/src/core/attribute.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,19 @@ class Model;
// of ONNX generated wrappers.
using AttributeProto_AttributeType = decltype(ONNX_NAMESPACE::AttributeProto{}.type());

namespace error {
namespace attribute {
namespace detail {
OPENVINO_SUPPRESS_DEPRECATED_START
struct Attribute : ngraph_error {
Attribute(const std::string& msg, AttributeProto_AttributeType type) : ngraph_error{msg} {}
};
OPENVINO_SUPPRESS_DEPRECATED_END

} // namespace detail

struct InvalidData : detail::Attribute {
explicit InvalidData(AttributeProto_AttributeType type) : Attribute{"invalid attribute type", type} {}
};

struct UnsupportedType : detail::Attribute {
explicit UnsupportedType(AttributeProto_AttributeType type) : Attribute{"unsupported attribute type", type} {}
};

} // namespace attribute

} // namespace error

namespace detail {
namespace attribute {
template <typename T>
inline T get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
throw ngraph::onnx_import::error::attribute::UnsupportedType{attribute.type()};
OPENVINO_THROW("Unsupported attribute type");
}

#define ONNX_INVALID_ATTR(attr, expected) \
OPENVINO_THROW("Invalid attribute type ", \
ONNX_NAMESPACE::AttributeProto_AttributeType_Name(attr), \
" expected: ", \
expected)

template <>
inline float get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
switch (attribute.type()) {
Expand All @@ -62,7 +45,7 @@ inline float get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT:
return attribute.f();
default:
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT, FLOAT");
}
}

Expand All @@ -78,7 +61,7 @@ inline std::vector<float> get_value(const ONNX_NAMESPACE::AttributeProto& attrib
case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS:
return {std::begin(attribute.floats()), std::end(attribute.floats())};
default:
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT, INTS, FLOAT, FLOATS");
}
}

Expand All @@ -90,7 +73,7 @@ inline double get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
case ONNX_NAMESPACE::AttributeProto_AttributeType_INT:
return static_cast<double>(attribute.i());
default:
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT, FLOAT");
}
}

Expand All @@ -110,7 +93,7 @@ inline std::vector<double> get_value(const ONNX_NAMESPACE::AttributeProto& attri
case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS:
return {std::begin(attribute.floats()), std::end(attribute.floats())};
default:
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT, INTS, FLOAT, FLOATS");
}
#if defined(_MSC_VER)
# pragma warning(pop)
Expand All @@ -120,7 +103,7 @@ inline std::vector<double> get_value(const ONNX_NAMESPACE::AttributeProto& attri
template <>
inline std::size_t get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
if (attribute.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_INT) {
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT");
}
return static_cast<std::size_t>(attribute.i());
}
Expand All @@ -133,14 +116,14 @@ inline std::vector<std::size_t> get_value(const ONNX_NAMESPACE::AttributeProto&
case ONNX_NAMESPACE::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.ints()), std::end(attribute.ints())};
default:
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT, INTS");
}
}

template <>
inline int64_t get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
if (attribute.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_INT) {
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT");
}
return attribute.i();
}
Expand All @@ -153,14 +136,14 @@ inline std::vector<int64_t> get_value(const ONNX_NAMESPACE::AttributeProto& attr
case ONNX_NAMESPACE::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.ints()), std::end(attribute.ints())};
default:
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT, INTS");
}
}

template <>
inline std::string get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
if (attribute.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_STRING) {
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "STRING");
}
return attribute.s();
}
Expand All @@ -173,7 +156,7 @@ inline std::vector<std::string> get_value(const ONNX_NAMESPACE::AttributeProto&
case ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS:
return {std::begin(attribute.strings()), std::end(attribute.strings())};
default:
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "STRING, STRINGS");
}
}

Expand Down Expand Up @@ -320,7 +303,7 @@ class Attribute {
if (is_tensor()) {
return Tensor{m_attribute_proto->t(), m_model_dir, m_mmap_cache};
}
throw error::attribute::InvalidData{m_attribute_proto->type()};
ONNX_INVALID_ATTR(m_attribute_proto->type(), "TENSOR");
}

template <typename T, typename std::enable_if<std::is_same<T, std::vector<Tensor>>::value, bool>::type = true>
Expand All @@ -330,15 +313,15 @@ class Attribute {
} else if (is_tensor_array()) {
return get_tensor_array();
}
throw error::attribute::InvalidData{m_attribute_proto->type()};
ONNX_INVALID_ATTR(m_attribute_proto->type(), "TENSOR, TENSORS");
}

template <typename T, typename std::enable_if<std::is_same<T, SparseTensor>::value, bool>::type = true>
T get_value() const {
if (is_sparse_tensor()) {
return SparseTensor{m_attribute_proto->sparse_tensor(), m_model_dir, m_mmap_cache};
}
throw error::attribute::InvalidData{m_attribute_proto->type()};
ONNX_INVALID_ATTR(m_attribute_proto->type(), "SPARSE_TENSOR");
}

template <typename T, typename std::enable_if<std::is_same<T, std::vector<SparseTensor>>::value, bool>::type = true>
Expand All @@ -348,7 +331,7 @@ class Attribute {
} else if (is_sparse_tensor_array()) {
return get_sparse_tensor_array();
}
throw error::attribute::InvalidData{m_attribute_proto->type()};
ONNX_INVALID_ATTR(m_attribute_proto->type(), "SPARSE_TENSOR, SPARSE_TENSORS");
}

ov::Any get_any() const;
Expand Down
Loading

0 comments on commit 72e774f

Please sign in to comment.