Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Frontend refactoring #21931

Merged
merged 5 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/linux_arm64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ jobs:

ONNX_Runtime:
name: ONNX Runtime Integration
if: fromJSON(needs.smart_ci.outputs.affected_components).ONNX_RT
if: fromJSON(needs.smart_ci.outputs.affected_components).ONNX_RT ||
fromJSON(needs.smart_ci.outputs.affected_components).ONNX_FE
needs: [ Build, Smart_CI ]
uses: ./.github/workflows/job_onnx_runtime.yml
with:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ namespace ngraph {
namespace onnx_import {
namespace error {
namespace node {
struct UnknownAttribute : ngraph_error {
struct UnknownAttribute : ov::Exception {
OPENVINO_SUPPRESS_DEPRECATED_START
explicit UnknownAttribute(const std::string& node, const std::string& name)
: ngraph_error{"Node (" + node + "): unknown attribute \'" + name + "\'"} {}
: ov::Exception{"Node (" + node + "): unknown attribute \'" + name + "\'"} {}
p-durandin marked this conversation as resolved.
Show resolved Hide resolved
OPENVINO_SUPPRESS_DEPRECATED_END
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,24 @@

#include <memory>

#include "ngraph/deprecated.hpp"
#include "onnx_import/onnx_importer_visibility.hpp"
#include "openvino/core/deprecated.hpp"
#include "openvino/op/op.hpp"

namespace ngraph {
namespace ov {
namespace op {
NGRAPH_API_DEPRECATED ONNX_IMPORTER_API bool is_null(const ngraph::Node* node);
NGRAPH_API_DEPRECATED ONNX_IMPORTER_API bool is_null(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API_DEPRECATED ONNX_IMPORTER_API bool is_null(const Output<ngraph::Node>& output);
namespace util {
NGRAPH_API_DEPRECATED ONNX_IMPORTER_API bool is_null(const ov::Node* node);
NGRAPH_API_DEPRECATED ONNX_IMPORTER_API bool is_null(const std::shared_ptr<ov::Node>& node);
NGRAPH_API_DEPRECATED ONNX_IMPORTER_API bool is_null(const Output<ov::Node>& output);
} // namespace util
} // namespace op
} // namespace ov
namespace ngraph {
namespace op {
using namespace ov::op::util;
}

namespace onnx_import {
/// \brief Represents a missing optional input or output of an ONNX node
///
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/onnx/frontend/src/core/attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace ngraph {
namespace onnx_import {
Subgraph Attribute::get_subgraph(Graph* parent_graph) const {
if (m_attribute_proto->type() != ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH) {
throw error::attribute::InvalidData{m_attribute_proto->type()};
ONNX_INVALID_ATTR(m_attribute_proto->type(), "GRAPH");
}

auto model_proto = std::make_shared<ONNX_NAMESPACE::ModelProto>();
Expand Down
59 changes: 21 additions & 38 deletions src/frontends/onnx/frontend/src/core/attribute.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,19 @@ class Model;
// of ONNX generated wrappers.
using AttributeProto_AttributeType = decltype(ONNX_NAMESPACE::AttributeProto{}.type());

namespace error {
namespace attribute {
namespace detail {
OPENVINO_SUPPRESS_DEPRECATED_START
struct Attribute : ngraph_error {
Attribute(const std::string& msg, AttributeProto_AttributeType type) : ngraph_error{msg} {}
};
OPENVINO_SUPPRESS_DEPRECATED_END

} // namespace detail

struct InvalidData : detail::Attribute {
explicit InvalidData(AttributeProto_AttributeType type) : Attribute{"invalid attribute type", type} {}
};

struct UnsupportedType : detail::Attribute {
explicit UnsupportedType(AttributeProto_AttributeType type) : Attribute{"unsupported attribute type", type} {}
};

} // namespace attribute

} // namespace error

namespace detail {
namespace attribute {
template <typename T>
inline T get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
throw ngraph::onnx_import::error::attribute::UnsupportedType{attribute.type()};
OPENVINO_THROW("Unsupported attribute type");
}

#define ONNX_INVALID_ATTR(attr, expected) \
OPENVINO_THROW("Invalid attribute type ", \
ONNX_NAMESPACE::AttributeProto_AttributeType_Name(attr), \
" expected: ", \
expected)

template <>
inline float get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
switch (attribute.type()) {
Expand All @@ -62,7 +45,7 @@ inline float get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT:
return attribute.f();
default:
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT, FLOAT");
}
}

Expand All @@ -78,7 +61,7 @@ inline std::vector<float> get_value(const ONNX_NAMESPACE::AttributeProto& attrib
case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS:
return {std::begin(attribute.floats()), std::end(attribute.floats())};
default:
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT, INTS, FLOAT, FLOATS");
}
}

Expand All @@ -90,7 +73,7 @@ inline double get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
case ONNX_NAMESPACE::AttributeProto_AttributeType_INT:
return static_cast<double>(attribute.i());
default:
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT, FLOAT");
}
}

Expand All @@ -110,7 +93,7 @@ inline std::vector<double> get_value(const ONNX_NAMESPACE::AttributeProto& attri
case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS:
return {std::begin(attribute.floats()), std::end(attribute.floats())};
default:
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT, INTS, FLOAT, FLOATS");
}
#if defined(_MSC_VER)
# pragma warning(pop)
Expand All @@ -120,7 +103,7 @@ inline std::vector<double> get_value(const ONNX_NAMESPACE::AttributeProto& attri
template <>
inline std::size_t get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
if (attribute.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_INT) {
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT");
}
return static_cast<std::size_t>(attribute.i());
}
Expand All @@ -133,14 +116,14 @@ inline std::vector<std::size_t> get_value(const ONNX_NAMESPACE::AttributeProto&
case ONNX_NAMESPACE::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.ints()), std::end(attribute.ints())};
default:
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT, INTS");
}
}

template <>
inline int64_t get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
if (attribute.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_INT) {
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT");
}
return attribute.i();
}
Expand All @@ -153,14 +136,14 @@ inline std::vector<int64_t> get_value(const ONNX_NAMESPACE::AttributeProto& attr
case ONNX_NAMESPACE::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.ints()), std::end(attribute.ints())};
default:
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "INT, INTS");
}
}

template <>
inline std::string get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
if (attribute.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_STRING) {
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "STRING");
}
return attribute.s();
}
Expand All @@ -173,7 +156,7 @@ inline std::vector<std::string> get_value(const ONNX_NAMESPACE::AttributeProto&
case ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS:
return {std::begin(attribute.strings()), std::end(attribute.strings())};
default:
throw error::attribute::InvalidData{attribute.type()};
ONNX_INVALID_ATTR(attribute.type(), "STRING, STRINGS");
}
}

Expand Down Expand Up @@ -320,7 +303,7 @@ class Attribute {
if (is_tensor()) {
return Tensor{m_attribute_proto->t(), m_model_dir, m_mmap_cache};
}
throw error::attribute::InvalidData{m_attribute_proto->type()};
ONNX_INVALID_ATTR(m_attribute_proto->type(), "TENSOR");
}

template <typename T, typename std::enable_if<std::is_same<T, std::vector<Tensor>>::value, bool>::type = true>
Expand All @@ -330,15 +313,15 @@ class Attribute {
} else if (is_tensor_array()) {
return get_tensor_array();
}
throw error::attribute::InvalidData{m_attribute_proto->type()};
ONNX_INVALID_ATTR(m_attribute_proto->type(), "TENSOR, TENSORS");
}

template <typename T, typename std::enable_if<std::is_same<T, SparseTensor>::value, bool>::type = true>
T get_value() const {
if (is_sparse_tensor()) {
return SparseTensor{m_attribute_proto->sparse_tensor(), m_model_dir, m_mmap_cache};
}
throw error::attribute::InvalidData{m_attribute_proto->type()};
ONNX_INVALID_ATTR(m_attribute_proto->type(), "SPARSE_TENSOR");
}

template <typename T, typename std::enable_if<std::is_same<T, std::vector<SparseTensor>>::value, bool>::type = true>
Expand All @@ -348,7 +331,7 @@ class Attribute {
} else if (is_sparse_tensor_array()) {
return get_sparse_tensor_array();
}
throw error::attribute::InvalidData{m_attribute_proto->type()};
ONNX_INVALID_ATTR(m_attribute_proto->type(), "SPARSE_TENSOR, SPARSE_TENSORS");
}

ov::Any get_any() const;
Expand Down
Loading
Loading