Skip to content

Commit

Permalink
Merge branch 'master' into river/remove_functional_test_utils_precisi…
Browse files Browse the repository at this point in the history
…on_utils.hpp
  • Loading branch information
riverlijunjie committed Feb 4, 2024
2 parents b5800b5 + b60526c commit d9faef5
Show file tree
Hide file tree
Showing 422 changed files with 3,704 additions and 5,943 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,33 @@

#pragma once

#include "openvino/core/deprecated.hpp"
#include "openvino/frontend/extension/conversion.hpp"
#include "openvino/frontend/node_context.hpp"
#include "openvino/frontend/onnx/visibility.hpp"

OPENVINO_SUPPRESS_DEPRECATED_START
namespace ngraph {
namespace onnx_import {
class Node;
}
} // namespace ngraph

namespace ov {
namespace frontend {
namespace onnx {
class Node;

class ONNX_FRONTEND_API NodeContext : public ov::frontend::NodeContext {
public:
using Ptr = std::shared_ptr<NodeContext>;
explicit NodeContext(const ngraph::onnx_import::Node& context);
explicit NodeContext(const ov::frontend::onnx::Node& context);
size_t get_input_size() const override;

Output<ov::Node> get_input(int port_idx) const override;

ov::Any get_attribute_as_any(const std::string& name) const override;

protected:
const ngraph::onnx_import::Node& m_context;
const ov::frontend::onnx::Node& m_context;
ov::OutputVector m_inputs;

private:
ov::Any apply_additional_conversion_rules(const ov::Any& data, const std::type_info& type_info) const override;
};
using CreatorFunction = std::function<ov::OutputVector(const ngraph::onnx_import::Node&)>;
using CreatorFunction = std::function<ov::OutputVector(const ov::frontend::onnx::Node&)>;
} // namespace onnx
} // namespace frontend
} // namespace ov
OPENVINO_SUPPRESS_DEPRECATED_END
15 changes: 8 additions & 7 deletions src/frontends/onnx/frontend/src/core/attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
#include "core/graph.hpp"
#include "core/model.hpp"

namespace ngraph {
namespace onnx_import {
namespace ov {
namespace frontend {
namespace onnx {
Subgraph Attribute::get_subgraph(Graph* parent_graph) const {
if (m_attribute_proto->type() != ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH) {
if (m_attribute_proto->type() != AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPH) {
ONNX_INVALID_ATTR(m_attribute_proto->type(), "GRAPH");
}

auto model_proto = std::make_shared<ONNX_NAMESPACE::ModelProto>();
auto model_proto = std::make_shared<ModelProto>();

const auto& graph = m_attribute_proto->g();
model_proto->mutable_graph()->CopyFrom(graph);
Expand Down Expand Up @@ -69,6 +70,6 @@ ov::Any Attribute::get_any() const {
}
}

} // namespace onnx_import

} // namespace ngraph
} // namespace onnx
} // namespace frontend
} // namespace ov
124 changes: 59 additions & 65 deletions src/frontends/onnx/frontend/src/core/attribute.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,87 +10,82 @@
#include "core/tensor.hpp"
#include "openvino/core/except.hpp"

namespace ngraph {
namespace onnx_import {
namespace ov {
namespace frontend {
namespace onnx {
// forward declarations
class Graph;
class Subgraph;
class Model;

// Detecting automatically the underlying type used to store the information
// for data type of values an attribute is holding. A bug was discovered in
// protobuf which forced ONNX team to switch from `enum AttributeProto_AttributeType`
// to `int32` in order to workaround the bug. This line allows using both versions
// of ONNX generated wrappers.
using AttributeProto_AttributeType = decltype(ONNX_NAMESPACE::AttributeProto{}.type());
using ::ONNX_NAMESPACE::AttributeProto;
using ::ONNX_NAMESPACE::AttributeProto_AttributeType;
using ::ONNX_NAMESPACE::AttributeProto_AttributeType_Name;

namespace detail {
namespace attribute {
template <typename T>
inline T get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
inline T get_value(const AttributeProto& attribute) {
OPENVINO_THROW("Unsupported attribute type");
}

#define ONNX_INVALID_ATTR(attr, expected) \
OPENVINO_THROW("Invalid attribute type ", \
ONNX_NAMESPACE::AttributeProto_AttributeType_Name(attr), \
" expected: ", \
expected)
#define ONNX_INVALID_ATTR(attr, expected) \
OPENVINO_THROW("Invalid attribute type ", AttributeProto_AttributeType_Name(attr), " expected: ", expected)

template <>
inline float get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
inline float get_value(const AttributeProto& attribute) {
switch (attribute.type()) {
case ONNX_NAMESPACE::AttributeProto_AttributeType_INT:
case AttributeProto_AttributeType::AttributeProto_AttributeType_INT:
return static_cast<float>(attribute.i());
case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT:
case AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT:
return attribute.f();
default:
ONNX_INVALID_ATTR(attribute.type(), "INT, FLOAT");
}
}

template <>
inline std::vector<float> get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
inline std::vector<float> get_value(const AttributeProto& attribute) {
switch (attribute.type()) {
case ONNX_NAMESPACE::AttributeProto_AttributeType_INT:
case AttributeProto_AttributeType::AttributeProto_AttributeType_INT:
return {static_cast<float>(attribute.i())};
case ONNX_NAMESPACE::AttributeProto_AttributeType_INTS:
case AttributeProto_AttributeType::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.floats()), std::end(attribute.floats())};
case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT:
case AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT:
return {attribute.f()};
case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS:
case AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS:
return {std::begin(attribute.floats()), std::end(attribute.floats())};
default:
ONNX_INVALID_ATTR(attribute.type(), "INT, INTS, FLOAT, FLOATS");
}
}

template <>
inline double get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
inline double get_value(const AttributeProto& attribute) {
switch (attribute.type()) {
case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT:
case AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT:
return static_cast<double>(attribute.f());
case ONNX_NAMESPACE::AttributeProto_AttributeType_INT:
case AttributeProto_AttributeType::AttributeProto_AttributeType_INT:
return static_cast<double>(attribute.i());
default:
ONNX_INVALID_ATTR(attribute.type(), "INT, FLOAT");
}
}

template <>
inline std::vector<double> get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
inline std::vector<double> get_value(const AttributeProto& attribute) {
#if defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable : 4244)
#endif
switch (attribute.type()) {
case ONNX_NAMESPACE::AttributeProto_AttributeType_INT:
case AttributeProto_AttributeType::AttributeProto_AttributeType_INT:
return {static_cast<double>(attribute.i())};
case ONNX_NAMESPACE::AttributeProto_AttributeType_INTS:
case AttributeProto_AttributeType::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.ints()), std::end(attribute.ints())};
case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT:
case AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT:
return {static_cast<double>(attribute.f())};
case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS:
case AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS:
return {std::begin(attribute.floats()), std::end(attribute.floats())};
default:
ONNX_INVALID_ATTR(attribute.type(), "INT, INTS, FLOAT, FLOATS");
Expand All @@ -101,89 +96,88 @@ inline std::vector<double> get_value(const ONNX_NAMESPACE::AttributeProto& attri
}

template <>
inline std::size_t get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
if (attribute.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_INT) {
inline std::size_t get_value(const AttributeProto& attribute) {
if (attribute.type() != AttributeProto_AttributeType::AttributeProto_AttributeType_INT) {
ONNX_INVALID_ATTR(attribute.type(), "INT");
}
return static_cast<std::size_t>(attribute.i());
}

template <>
inline std::vector<std::size_t> get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
inline std::vector<std::size_t> get_value(const AttributeProto& attribute) {
switch (attribute.type()) {
case ONNX_NAMESPACE::AttributeProto_AttributeType_INT:
case AttributeProto_AttributeType::AttributeProto_AttributeType_INT:
return {static_cast<std::size_t>(attribute.i())};
case ONNX_NAMESPACE::AttributeProto_AttributeType_INTS:
case AttributeProto_AttributeType::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.ints()), std::end(attribute.ints())};
default:
ONNX_INVALID_ATTR(attribute.type(), "INT, INTS");
}
}

template <>
inline int64_t get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
if (attribute.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_INT) {
inline int64_t get_value(const AttributeProto& attribute) {
if (attribute.type() != AttributeProto_AttributeType::AttributeProto_AttributeType_INT) {
ONNX_INVALID_ATTR(attribute.type(), "INT");
}
return attribute.i();
}

template <>
inline std::vector<int64_t> get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
inline std::vector<int64_t> get_value(const AttributeProto& attribute) {
switch (attribute.type()) {
case ONNX_NAMESPACE::AttributeProto_AttributeType_INT:
case AttributeProto_AttributeType::AttributeProto_AttributeType_INT:
return {attribute.i()};
case ONNX_NAMESPACE::AttributeProto_AttributeType_INTS:
case AttributeProto_AttributeType::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.ints()), std::end(attribute.ints())};
default:
ONNX_INVALID_ATTR(attribute.type(), "INT, INTS");
}
}

template <>
inline std::string get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
if (attribute.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_STRING) {
inline std::string get_value(const AttributeProto& attribute) {
if (attribute.type() != AttributeProto_AttributeType::AttributeProto_AttributeType_STRING) {
ONNX_INVALID_ATTR(attribute.type(), "STRING");
}
return attribute.s();
}

template <>
inline std::vector<std::string> get_value(const ONNX_NAMESPACE::AttributeProto& attribute) {
inline std::vector<std::string> get_value(const AttributeProto& attribute) {
switch (attribute.type()) {
case ONNX_NAMESPACE::AttributeProto_AttributeType_STRING:
case AttributeProto_AttributeType::AttributeProto_AttributeType_STRING:
return {attribute.s()};
case ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS:
case AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS:
return {std::begin(attribute.strings()), std::end(attribute.strings())};
default:
ONNX_INVALID_ATTR(attribute.type(), "STRING, STRINGS");
}
}

} // namespace attribute

} // namespace detail

class Attribute {
public:
enum class Type {
undefined = ONNX_NAMESPACE::AttributeProto_AttributeType_UNDEFINED,
float_point = ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT,
integer = ONNX_NAMESPACE::AttributeProto_AttributeType_INT,
string = ONNX_NAMESPACE::AttributeProto_AttributeType_STRING,
tensor = ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR,
graph = ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH,
sparse_tensor = ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSOR,
float_point_array = ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS,
integer_array = ONNX_NAMESPACE::AttributeProto_AttributeType_INTS,
string_array = ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS,
tensor_array = ONNX_NAMESPACE::AttributeProto_AttributeType_TENSORS,
sparse_tensor_array = ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSORS,
graph_array = ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPHS
undefined = AttributeProto_AttributeType::AttributeProto_AttributeType_UNDEFINED,
float_point = AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT,
integer = AttributeProto_AttributeType::AttributeProto_AttributeType_INT,
string = AttributeProto_AttributeType::AttributeProto_AttributeType_STRING,
tensor = AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR,
graph = AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPH,
sparse_tensor = AttributeProto_AttributeType::AttributeProto_AttributeType_SPARSE_TENSOR,
float_point_array = AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS,
integer_array = AttributeProto_AttributeType::AttributeProto_AttributeType_INTS,
string_array = AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS,
tensor_array = AttributeProto_AttributeType::AttributeProto_AttributeType_TENSORS,
sparse_tensor_array = AttributeProto_AttributeType::AttributeProto_AttributeType_SPARSE_TENSORS,
graph_array = AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPHS
};

Attribute() = delete;
Attribute(const ONNX_NAMESPACE::AttributeProto& attribute_proto,
Attribute(const AttributeProto& attribute_proto,
const std::string& model_dir,
detail::MappedMemoryHandles mmap_cache)
: m_attribute_proto{&attribute_proto},
Expand Down Expand Up @@ -285,7 +279,7 @@ class Attribute {
return {std::begin(m_attribute_proto->strings()), std::end(m_attribute_proto->strings())};
}

/* explicit */ operator ONNX_NAMESPACE::AttributeProto_AttributeType() const {
/* explicit */ operator AttributeProto_AttributeType() const {
return m_attribute_proto->type();
}

Expand Down Expand Up @@ -337,11 +331,11 @@ class Attribute {
ov::Any get_any() const;

private:
const ONNX_NAMESPACE::AttributeProto* m_attribute_proto;
const AttributeProto* m_attribute_proto;
std::string m_model_dir;
detail::MappedMemoryHandles m_mmap_cache;
};

} // namespace onnx_import

} // namespace ngraph
} // namespace onnx
} // namespace frontend
} // namespace ov
Loading

0 comments on commit d9faef5

Please sign in to comment.