[ONNX] Split importing model to two phases: decode and convert (openv…
mateusztabaka authored and akuporos committed Sep 29, 2021
1 parent 1a27c18 commit 5109de4
Showing 22 changed files with 461 additions and 253 deletions.
2 changes: 2 additions & 0 deletions inference-engine/src/plugin_api/ie_ngraph_utils.hpp
@@ -134,6 +134,8 @@ inline Precision convertPrecision(const ::ngraph::element::Type& precision) {
return Precision(Precision::BIN);
case ::ngraph::element::Type_t::boolean:
return Precision(Precision::BOOL);
+ case ::ngraph::element::Type_t::dynamic:
+     return Precision(Precision::UNSPECIFIED);
default:
IE_THROW() << "Incorrect precision " << precision.get_type_name() << "!"; return{};
}
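With the new case, a dynamic element type maps to Precision::UNSPECIFIED instead of falling through to the throwing default branch. A minimal sketch of the resulting behaviour, assuming convertPrecision is reachable through InferenceEngine::details as declared in this plugin API header:

#include <ie_ngraph_utils.hpp>
#include <ngraph/type/element_type.hpp>

// Sketch only: a dynamic element type now converts cleanly instead of throwing.
inline InferenceEngine::Precision precision_of_dynamic()
{
    const ngraph::element::Type dynamic_type = ngraph::element::Type_t::dynamic;
    // With this commit the result is Precision::UNSPECIFIED.
    return InferenceEngine::details::convertPrecision(dynamic_type);
}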
@@ -55,7 +55,7 @@ class TRANSFORMATIONS_API FrameworkNode : public Op {
public:
NGRAPH_RTTI_DECLARATION;

- explicit FrameworkNode(const OutputVector& inputs);
+ explicit FrameworkNode(const OutputVector& inputs, size_t output_size = 1);

void validate_and_infer_types() override;

@@ -10,8 +10,9 @@ using namespace ngraph;

NGRAPH_RTTI_DEFINITION(op::FrameworkNode, "FrameworkNode", 0);

- op::FrameworkNode::FrameworkNode(const OutputVector& inputs)
+ op::FrameworkNode::FrameworkNode(const OutputVector& inputs, size_t output_size)
: Op(inputs) {
+ set_output_size(output_size);
constructor_validate_and_infer_types();
}

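The added output_size parameter lets one placeholder node expose as many outputs as the framework operation it stands in for. A rough usage sketch; the helper function and its name are illustrative, only the constructor comes from the diff above:

#include <memory>
#include <ngraph/ngraph.hpp>
#include <ngraph_ops/framework_node.hpp>

// Sketch: wrap the given inputs in a placeholder that reserves two outputs.
std::shared_ptr<ngraph::op::FrameworkNode>
    make_two_output_placeholder(const ngraph::OutputVector& inputs)
{
    // output_size defaults to 1; pass 2 for an operation with two results.
    return std::make_shared<ngraph::op::FrameworkNode>(inputs, /*output_size=*/2);
}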
@@ -495,6 +495,7 @@ std::string get_opset_name(
std::string get_precision_name(const ngraph::element::Type & elem_type) {
switch (elem_type) {
case ::ngraph::element::Type_t::undefined:
+ case ::ngraph::element::Type_t::dynamic:
return "UNSPECIFIED";
case ::ngraph::element::Type_t::f16:
return "FP16";
2 changes: 1 addition & 1 deletion ngraph/frontend/onnx_import/CMakeLists.txt
@@ -45,7 +45,7 @@ if(COMMAND ie_faster_build)
)
endif()

- target_link_libraries(onnx_importer PRIVATE onnx_common ngraph::builder
+ target_link_libraries(onnx_importer PRIVATE onnx_common ngraph::builder inference_engine_transformations
PUBLIC ngraph)

target_include_directories(onnx_importer PUBLIC $<BUILD_INTERFACE:${ONNX_IMPORT_INCLUDE_DIR}>
5 changes: 2 additions & 3 deletions ngraph/frontend/onnx_import/include/onnx_import/core/node.hpp
@@ -75,9 +75,8 @@ namespace ngraph

bool has_attribute(const std::string& name) const;

- Subgraph get_subgraph_from_attribute(
-     const std::string& name,
-     const std::map<std::size_t, std::string>& carried_dependencies_map) const;
+ bool has_subgraph() const;
+ std::shared_ptr<Subgraph> get_subgraph() const;

template <typename T>
T get_attribute_value(const std::string& name, T default_value) const;
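The two new accessors replace get_subgraph_from_attribute, so callers no longer pass a carried-dependencies map. A hedged sketch of how an operator with a body graph (for example Loop or If) might use the new interface; the free function is made up for illustration:

#include <onnx_import/core/node.hpp>

// Sketch: ask the node itself for its subgraph instead of resolving it
// from an attribute plus a carried-dependencies map.
void inspect_body(const ngraph::onnx_import::Node& node)
{
    if (node.has_subgraph())
    {
        auto body = node.get_subgraph(); // std::shared_ptr<Subgraph>
        (void)body;                      // e.g. convert or inspect the body here
    }
}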
@@ -0,0 +1,100 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <core/graph.hpp>
#include <ngraph/visibility.hpp>
#include <ngraph_ops/framework_node.hpp>
#include <onnx_import/core/node.hpp>

namespace ONNX_NAMESPACE
{
// forward declaration
class ModelProto;
} // namespace ONNX_NAMESPACE

namespace ngraph
{
namespace onnx_import
{
class Model;
}

namespace frontend
{
class ONNXFrameworkNode : public op::FrameworkNode
{
public:
NGRAPH_RTTI_DECLARATION;

ONNXFrameworkNode(const onnx_import::Node& node)
: FrameworkNode(node.get_ng_inputs(), node.get_outputs_size())
, m_node(node)
{
}

ONNXFrameworkNode(const onnx_import::Node& node, const OutputVector& inputs)
: FrameworkNode(inputs, node.get_outputs_size())
, m_node(node)
{
}

const onnx_import::Node& get_onnx_node() const { return m_node; }

virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& inputs) const override;

virtual bool visit_attributes(AttributeVisitor& visitor) override
{
// TODO: implement reading as well; for now this works for serialization only
std::string domain = m_node.domain();
std::string op_type = m_node.op_type();
visitor.on_attribute("ONNX_META_domain", domain);
visitor.on_attribute("ONNX_META_type", op_type);
return true;
}

private:
onnx_import::Node m_node;
};

class ONNXSubgraphFrameworkNode : public ONNXFrameworkNode
{
public:
NGRAPH_RTTI_DECLARATION;

ONNXSubgraphFrameworkNode(const onnx_import::Node& node, const OutputVector& inputs)
: ONNXFrameworkNode(node, inputs)
{
}

void infer_inputs_from_parent()
{
get_onnx_node().get_subgraph()->infer_inputs_from_parent();
}

std::shared_ptr<Function> get_subgraph_body() const
{
auto subgraph = get_onnx_node().get_subgraph();
return std::make_shared<Function>(subgraph->get_ng_outputs(),
subgraph->get_ng_parameters(),
subgraph->get_name());
}
};

} // namespace frontend
} // namespace ngraph
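This new header is the heart of the decode/convert split: during decoding every ONNX operation can be wrapped into an ONNXFrameworkNode placeholder, and the original onnx_import::Node stays reachable for the later convert pass. A rough sketch of that flow; the two free functions are hypothetical, only the constructor and get_onnx_node come from the class above, and the header's include path is assumed since the file name is not shown here:

#include <memory>
#include <onnx_import/core/node.hpp>
// #include "onnx_framework_node.hpp"  // path assumed; not visible in this diff view

// Sketch of the decode phase: wrap the node without running any converter.
std::shared_ptr<ngraph::frontend::ONNXFrameworkNode>
    decode(const ngraph::onnx_import::Node& onnx_node)
{
    return std::make_shared<ngraph::frontend::ONNXFrameworkNode>(onnx_node);
}

// Sketch of the convert phase: recover the original ONNX node so the real
// operator converter can be applied later.
void convert(const std::shared_ptr<ngraph::frontend::ONNXFrameworkNode>& placeholder)
{
    const ngraph::onnx_import::Node& original = placeholder->get_onnx_node();
    (void)original; // look up a converter by original.op_type() / original.domain()
}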
31 changes: 1 addition & 30 deletions ngraph/frontend/onnx_import/src/core/attribute.cpp
@@ -11,9 +11,7 @@ namespace ngraph
{
namespace onnx_import
{
- Subgraph Attribute::get_subgraph(
-     const Graph& parent_graph,
-     const std::map<std::size_t, std::string>& carried_dependencies_map) const
+ Subgraph Attribute::get_subgraph(const Graph& parent_graph) const
{
if (m_attribute_proto->type() != ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH)
{
@@ -25,33 +23,6 @@ namespace ngraph
const auto& graph = m_attribute_proto->g();
model_proto->mutable_graph()->CopyFrom(graph);

- const std::size_t subgraph_inputs_count =
-     static_cast<size_t>(model_proto->mutable_graph()->mutable_input()->size());
- // Use the `carried_dependencies_map` to infer the types for the subgraph inputs
- for (const auto& carried_dependency : carried_dependencies_map)
- {
-     if (carried_dependency.first >= subgraph_inputs_count)
-     {
-         NGRAPH_WARN << "Input with index: '" << carried_dependency.first
-                     << "' was not found in the subgraph";
-     }
-     else
-     {
-         const auto& parent_in =
-             parent_graph.get_ng_node_from_cache(carried_dependency.second);
-         const auto& carried_type = parent_in.get_element_type();
-         auto subgraph_in =
-             model_proto->mutable_graph()->mutable_input(carried_dependency.first);
-         auto subgraph_in_tensor_type =
-             subgraph_in->mutable_type()->mutable_tensor_type();
-         if (!subgraph_in_tensor_type->has_elem_type())
-         {
-             subgraph_in_tensor_type->set_elem_type(
-                 onnx_common::ng_to_onnx_data_type(carried_type));
-         }
-     }
- }

// set opset version and domain from the parent graph
model_proto->mutable_opset_import()->CopyFrom(parent_graph.get_opset_imports());
auto model = common::make_unique<Model>(std::move(model_proto));
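With the carried-dependencies handling gone, resolving a GRAPH attribute needs only the parent graph. A minimal sketch of the simplified call; the wrapper function is illustrative, and includes are omitted because Attribute and Graph live in internal onnx_import headers:

// Sketch: the attribute builds the subgraph model itself; input types are
// inferred later, during conversion, rather than pre-filled here.
ngraph::onnx_import::Subgraph resolve_subgraph(
    const ngraph::onnx_import::Attribute& attribute,
    const ngraph::onnx_import::Graph& parent_graph)
{
    return attribute.get_subgraph(parent_graph);
}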
4 changes: 1 addition & 3 deletions ngraph/frontend/onnx_import/src/core/attribute.hpp
@@ -316,9 +316,7 @@ namespace ngraph
float get_float() const { return m_attribute_proto->f(); }
int64_t get_integer() const { return m_attribute_proto->i(); }
const std::string& get_string() const { return m_attribute_proto->s(); }
- Subgraph get_subgraph(
-     const Graph& parent_graph,
-     const std::map<std::size_t, std::string>& carried_dependencies_map) const;
+ Subgraph get_subgraph(const Graph& parent_graph) const;

std::vector<Tensor> get_tensor_array() const
{