diff --git a/inference-engine/src/plugin_api/ie_ngraph_utils.hpp b/inference-engine/src/plugin_api/ie_ngraph_utils.hpp index 40904bb07215ca..48a9a026daceab 100644 --- a/inference-engine/src/plugin_api/ie_ngraph_utils.hpp +++ b/inference-engine/src/plugin_api/ie_ngraph_utils.hpp @@ -134,6 +134,8 @@ inline Precision convertPrecision(const ::ngraph::element::Type& precision) { return Precision(Precision::BIN); case ::ngraph::element::Type_t::boolean: return Precision(Precision::BOOL); + case ::ngraph::element::Type_t::dynamic: + return Precision(Precision::UNSPECIFIED); default: IE_THROW() << "Incorrect precision " << precision.get_type_name() << "!"; return{}; } diff --git a/inference-engine/src/transformations/include/ngraph_ops/framework_node.hpp b/inference-engine/src/transformations/include/ngraph_ops/framework_node.hpp index 1a5729d8ecdff5..8abda399c9c049 100644 --- a/inference-engine/src/transformations/include/ngraph_ops/framework_node.hpp +++ b/inference-engine/src/transformations/include/ngraph_ops/framework_node.hpp @@ -55,7 +55,7 @@ class TRANSFORMATIONS_API FrameworkNode : public Op { public: NGRAPH_RTTI_DECLARATION; - explicit FrameworkNode(const OutputVector& inputs); + explicit FrameworkNode(const OutputVector& inputs, size_t output_size = 1); void validate_and_infer_types() override; diff --git a/inference-engine/src/transformations/src/ngraph_ops/framework_node.cpp b/inference-engine/src/transformations/src/ngraph_ops/framework_node.cpp index b25143c20f5aef..94d0008c11064e 100644 --- a/inference-engine/src/transformations/src/ngraph_ops/framework_node.cpp +++ b/inference-engine/src/transformations/src/ngraph_ops/framework_node.cpp @@ -10,8 +10,9 @@ using namespace ngraph; NGRAPH_RTTI_DEFINITION(op::FrameworkNode, "FrameworkNode", 0); -op::FrameworkNode::FrameworkNode(const OutputVector& inputs) +op::FrameworkNode::FrameworkNode(const OutputVector& inputs, size_t output_size) : Op(inputs) { + set_output_size(output_size); constructor_validate_and_infer_types(); } diff --git a/inference-engine/src/transformations/src/transformations/serialize.cpp b/inference-engine/src/transformations/src/transformations/serialize.cpp index 1fd41881125b41..93f9c24e4b81bb 100644 --- a/inference-engine/src/transformations/src/transformations/serialize.cpp +++ b/inference-engine/src/transformations/src/transformations/serialize.cpp @@ -495,6 +495,7 @@ std::string get_opset_name( std::string get_precision_name(const ngraph::element::Type & elem_type) { switch (elem_type) { case ::ngraph::element::Type_t::undefined: + case ::ngraph::element::Type_t::dynamic: return "UNSPECIFIED"; case ::ngraph::element::Type_t::f16: return "FP16"; diff --git a/ngraph/frontend/onnx_import/CMakeLists.txt b/ngraph/frontend/onnx_import/CMakeLists.txt index 0ddb78ad071510..bb6a4e7ff99580 100644 --- a/ngraph/frontend/onnx_import/CMakeLists.txt +++ b/ngraph/frontend/onnx_import/CMakeLists.txt @@ -45,7 +45,7 @@ if(COMMAND ie_faster_build) ) endif() -target_link_libraries(onnx_importer PRIVATE onnx_common ngraph::builder +target_link_libraries(onnx_importer PRIVATE onnx_common ngraph::builder inference_engine_transformations PUBLIC ngraph) target_include_directories(onnx_importer PUBLIC $ diff --git a/ngraph/frontend/onnx_import/include/onnx_import/core/node.hpp b/ngraph/frontend/onnx_import/include/onnx_import/core/node.hpp index c2a3e6b820cdbc..cb5d11fde31e5c 100644 --- a/ngraph/frontend/onnx_import/include/onnx_import/core/node.hpp +++ b/ngraph/frontend/onnx_import/include/onnx_import/core/node.hpp @@ -75,9 +75,8 @@ namespace ngraph bool has_attribute(const std::string& name) const; - Subgraph get_subgraph_from_attribute( - const std::string& name, - const std::map& carried_dependencies_map) const; + bool has_subgraph() const; + std::shared_ptr get_subgraph() const; template T get_attribute_value(const std::string& name, T default_value) const; diff --git a/ngraph/frontend/onnx_import/include/onnx_import/onnx_framework_node.hpp b/ngraph/frontend/onnx_import/include/onnx_import/onnx_framework_node.hpp new file mode 100644 index 00000000000000..bfa902a5ac449c --- /dev/null +++ b/ngraph/frontend/onnx_import/include/onnx_import/onnx_framework_node.hpp @@ -0,0 +1,100 @@ +//***************************************************************************** +// Copyright 2017-2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include +#include +#include +#include + +namespace ONNX_NAMESPACE +{ + // forward declaration + class ModelProto; +} // namespace ONNX_NAMESPACE + +namespace ngraph +{ + namespace onnx_import + { + class Model; + } + + namespace frontend + { + class ONNXFrameworkNode : public op::FrameworkNode + { + public: + NGRAPH_RTTI_DECLARATION; + + ONNXFrameworkNode(const onnx_import::Node& node) + : FrameworkNode(node.get_ng_inputs(), node.get_outputs_size()) + , m_node(node) + { + } + + ONNXFrameworkNode(const onnx_import::Node& node, const OutputVector& inputs) + : FrameworkNode(inputs, node.get_outputs_size()) + , m_node(node) + { + } + + const onnx_import::Node& get_onnx_node() const { return m_node; } + + virtual std::shared_ptr + clone_with_new_inputs(const OutputVector& inputs) const override; + + virtual bool visit_attributes(AttributeVisitor& visitor) override + { + // TODO: implement reading as well, now it work for serialization only + std::string domain = m_node.domain(); + std::string op_type = m_node.op_type(); + visitor.on_attribute("ONNX_META_domain", domain); + visitor.on_attribute("ONNX_META_type", op_type); + return true; + } + + private: + onnx_import::Node m_node; + }; + + class ONNXSubgraphFrameworkNode : public ONNXFrameworkNode + { + public: + NGRAPH_RTTI_DECLARATION; + + ONNXSubgraphFrameworkNode(const onnx_import::Node& node, const OutputVector& inputs) + : ONNXFrameworkNode(node, inputs) + { + } + + void infer_inputs_from_parent() + { + get_onnx_node().get_subgraph()->infer_inputs_from_parent(); + } + + std::shared_ptr get_subgraph_body() const + { + auto subgraph = get_onnx_node().get_subgraph(); + return std::make_shared(subgraph->get_ng_outputs(), + subgraph->get_ng_parameters(), + subgraph->get_name()); + } + }; + + } // namespace frontend +} // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/core/attribute.cpp b/ngraph/frontend/onnx_import/src/core/attribute.cpp index 8eaa8c93517d8e..1fd61931de9629 100644 --- a/ngraph/frontend/onnx_import/src/core/attribute.cpp +++ b/ngraph/frontend/onnx_import/src/core/attribute.cpp @@ -11,9 +11,7 @@ namespace ngraph { namespace onnx_import { - Subgraph Attribute::get_subgraph( - const Graph& parent_graph, - const std::map& carried_dependencies_map) const + Subgraph Attribute::get_subgraph(const Graph& parent_graph) const { if (m_attribute_proto->type() != ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH) { @@ -25,33 +23,6 @@ namespace ngraph const auto& graph = m_attribute_proto->g(); model_proto->mutable_graph()->CopyFrom(graph); - const std::size_t subgraph_inputs_count = - static_cast(model_proto->mutable_graph()->mutable_input()->size()); - // Use the `carried_dependencies_map` to infer the types for the subgraph inputs - for (const auto& carried_dependency : carried_dependencies_map) - { - if (carried_dependency.first >= subgraph_inputs_count) - { - NGRAPH_WARN << "Input with index: '" << carried_dependency.first - << "' was not found in the subgraph"; - } - else - { - const auto& parent_in = - parent_graph.get_ng_node_from_cache(carried_dependency.second); - const auto& carried_type = parent_in.get_element_type(); - auto subgraph_in = - model_proto->mutable_graph()->mutable_input(carried_dependency.first); - auto subgraph_in_tensor_type = - subgraph_in->mutable_type()->mutable_tensor_type(); - if (!subgraph_in_tensor_type->has_elem_type()) - { - subgraph_in_tensor_type->set_elem_type( - onnx_common::ng_to_onnx_data_type(carried_type)); - } - } - } - // set opset version and domain from the parent graph model_proto->mutable_opset_import()->CopyFrom(parent_graph.get_opset_imports()); auto model = common::make_unique(std::move(model_proto)); diff --git a/ngraph/frontend/onnx_import/src/core/attribute.hpp b/ngraph/frontend/onnx_import/src/core/attribute.hpp index bc192e7b392fcb..963dab22cb53de 100644 --- a/ngraph/frontend/onnx_import/src/core/attribute.hpp +++ b/ngraph/frontend/onnx_import/src/core/attribute.hpp @@ -316,9 +316,7 @@ namespace ngraph float get_float() const { return m_attribute_proto->f(); } int64_t get_integer() const { return m_attribute_proto->i(); } const std::string& get_string() const { return m_attribute_proto->s(); } - Subgraph get_subgraph( - const Graph& parent_graph, - const std::map& carried_dependencies_map) const; + Subgraph get_subgraph(const Graph& parent_graph) const; std::vector get_tensor_array() const { diff --git a/ngraph/frontend/onnx_import/src/core/graph.cpp b/ngraph/frontend/onnx_import/src/core/graph.cpp index c8f56327d6bb22..569d3849774859 100644 --- a/ngraph/frontend/onnx_import/src/core/graph.cpp +++ b/ngraph/frontend/onnx_import/src/core/graph.cpp @@ -14,6 +14,7 @@ #include "ngraph/node.hpp" #include "ngraph/provenance.hpp" #include "onnx_import/core/node.hpp" +#include "onnx_import/onnx_framework_node.hpp" #include "utils/common.hpp" #include "utils/provenance_tag.hpp" @@ -55,25 +56,6 @@ namespace ngraph Graph::Graph(std::unique_ptr&& model) : Graph(std::move(model), common::make_unique()) { - // Remove dangling Parameters - for (auto param_it = m_parameters.begin(); param_it != m_parameters.end();) - { - if ((*param_it)->get_output_target_inputs(0).size() == 0) - { - const auto& name = (*param_it)->get_friendly_name(); - auto out_it = std::find_if( - m_outputs.begin(), m_outputs.end(), [&name](const ValueInfo& info) { - return info.get_name() == name; - }); - if (out_it == m_outputs.end()) - { - m_cache->remove_node(name); - param_it = m_parameters.erase(param_it); - continue; - } - } - param_it++; - } } Graph::Graph(std::unique_ptr&& model, std::unique_ptr&& cache) @@ -174,14 +156,82 @@ namespace ngraph NGRAPH_CHECK(unknown_operators.empty(), "nGraph does not support the following ONNX operations: ", detail::to_string(unknown_operators)); + } + void Graph::convert_to_ngraph_nodes() + { // Process ONNX graph nodes, convert to nGraph nodes for (const auto& node_proto : m_model->get_graph().node()) { m_nodes.emplace_back(node_proto, *this); const Node& node{m_nodes.back()}; - + if (node.has_subgraph()) + { + auto subgraph = node.get_subgraph(); + auto body_func = subgraph->convert(); + } OutputVector ng_nodes{node.get_ng_nodes()}; + set_friendly_names(node, ng_nodes); + for (std::size_t i{0}; i < node.get_outputs_size(); ++i) + { + m_cache->emplace_node(node.output(i), std::move(ng_nodes.at(i))); + } + } + } + + void Graph::remove_dangling_parameters() + { + for (auto param_it = m_parameters.begin(); param_it != m_parameters.end();) + { + if ((*param_it)->get_output_target_inputs(0).size() == 0) + { + const auto& name = (*param_it)->get_friendly_name(); + auto out_it = std::find_if( + m_outputs.begin(), m_outputs.end(), [&name](const ValueInfo& info) { + return info.get_name() == name; + }); + if (out_it == m_outputs.end()) + { + m_cache->remove_node(name); + param_it = m_parameters.erase(param_it); + continue; + } + } + param_it++; + } + } + + std::shared_ptr Graph::convert() + { + convert_to_ngraph_nodes(); + remove_dangling_parameters(); + return create_function(); + } + + void Graph::decode_to_framework_nodes() + { + // Process ONNX graph nodes, convert to nGraph nodes + for (const auto& node_proto : m_model->get_graph().node()) + { + m_nodes.emplace_back(node_proto, *this); + const Node& node{m_nodes.back()}; + std::shared_ptr framework_node; + if (node.has_subgraph()) + { + auto subgraph = node.get_subgraph(); + auto body_func = subgraph->decode(); + auto inputs = node.get_ng_inputs(); + for (const auto& input : subgraph->get_inputs_from_parent()) + inputs.push_back(input); + framework_node = + std::make_shared(node, inputs); + } + else + { + framework_node = std::make_shared(node); + } + OutputVector ng_nodes{framework_node->outputs()}; + set_friendly_names(node, ng_nodes); // Iterate over the number of outputs for given node in graph. // Some of them may be optional and trimmed. See: // https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-inputs-and-outputs @@ -192,12 +242,24 @@ namespace ngraph } } - const GraphCache& Graph::get_graph_cache() const { return *m_cache.get(); } - bool Graph::is_node_in_cache(const std::string& name) const + std::shared_ptr Graph::create_function() + { + auto function = std::make_shared(get_ng_outputs(), m_parameters, get_name()); + for (std::size_t i{0}; i < function->get_output_size(); ++i) + { + function->get_output_op(i)->set_friendly_name(m_outputs.at(i).get_name()); + } + return function; + } + + std::shared_ptr Graph::decode() { - return m_cache->contains(name); + decode_to_framework_nodes(); + return create_function(); } + const GraphCache& Graph::get_graph_cache() const { return *m_cache.get(); } + Output Graph::get_ng_node_from_cache(const std::string& name) const { return m_cache->get_node(name); @@ -247,6 +309,12 @@ namespace ngraph set_friendly_names(onnx_node, ng_node_vector); add_provenance_tags(onnx_node, ng_node_vector); + for (std::size_t i{0}; i < onnx_node.get_outputs_size(); ++i) + { + auto ng_node = ng_node_vector.at(i); + m_cache->emplace_node(onnx_node.output(i), std::move(ng_node)); + } + return ng_node_vector; } @@ -323,9 +391,21 @@ namespace ngraph } Subgraph::Subgraph(std::unique_ptr&& model, const Graph& parent_graph) - : Graph( - std::move(model), - std::unique_ptr(new SubgraphCache(parent_graph.get_graph_cache()))) + : Graph(std::move(model), common::make_unique()) + , m_parent_graph_cache(&parent_graph.get_graph_cache()) + { + } + + Output Subgraph::get_ng_node_from_cache(const std::string& name) const + { + if (m_cache->contains(name)) + { + return m_cache->get_node(name); + } + return m_parent_graph_cache->get_node(name); + } + + void Subgraph::find_inputs_from_parent() { // find all nodes on edge parent graph-subgraph // (it means input of node from parent graph, output from subgraph) @@ -334,16 +414,16 @@ namespace ngraph int input_index = 0; for (const auto& in_name : node_proto.input()) { - if (m_cache->node_scope(in_name) == NodeScope::ParentGraph) + if (m_parent_graph_cache->contains(in_name)) { - const auto& from_parent_node = m_cache->get_node(in_name); + const auto& from_parent_node = m_parent_graph_cache->get_node(in_name); // constants are skipped if (!ngraph::is_type( from_parent_node.get_node_shared_ptr())) { for (const auto& out_name : node_proto.output()) { - if (m_cache->node_scope(out_name) == NodeScope::SubGraph) + if (m_cache->contains(out_name)) { auto out_node_to_replace_input = m_cache->get_node(out_name); auto new_param = std::make_shared( @@ -353,8 +433,10 @@ namespace ngraph out_node_to_replace_input.get_node() ->input(input_index) .replace_source_output(new_param); + m_parameter_to_parent_node_map.insert({new_param, in_name}); + m_cache->emplace_node(in_name, new_param); m_parameters.push_back(new_param); - m_outputs_from_parent.push_back(from_parent_node); + m_inputs_from_parent.push_back(in_name); } } } @@ -364,11 +446,39 @@ namespace ngraph } } - const std::vector> Subgraph::get_outputs_from_parent() const + std::shared_ptr Subgraph::convert() { - return m_outputs_from_parent; + convert_to_ngraph_nodes(); + find_inputs_from_parent(); + return create_function(); } + void Subgraph::decode_to_framework_nodes() + { + Graph::decode_to_framework_nodes(); + find_inputs_from_parent(); + } + + const std::vector> Subgraph::get_inputs_from_parent() const + { + OutputVector result; + for (const auto& name : m_inputs_from_parent) + { + result.push_back(m_parent_graph_cache->get_node(name)); + } + return result; + } + + void Subgraph::infer_inputs_from_parent() + { + for (auto& it : m_parameter_to_parent_node_map) + { + const auto& node = m_parent_graph_cache->get_node(it.second); + auto& parameter = it.first; + parameter->set_element_type(node.get_element_type()); + parameter->set_partial_shape(node.get_partial_shape()); + } + } } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/core/graph.hpp b/ngraph/frontend/onnx_import/src/core/graph.hpp index 6cbd880410984c..33c2be5d4d20e8 100644 --- a/ngraph/frontend/onnx_import/src/core/graph.hpp +++ b/ngraph/frontend/onnx_import/src/core/graph.hpp @@ -31,13 +31,14 @@ namespace ngraph Graph& operator=(const Graph&) = delete; Graph& operator=(Graph&&) = default; + virtual std::shared_ptr convert(); + std::shared_ptr decode(); const std::vector& get_nodes() const { return m_nodes; } const std::vector& get_inputs() const { return m_inputs; } const std::vector& get_outputs() const { return m_outputs; } OutputVector get_ng_outputs() const; const ParameterVector& get_ng_parameters() const { return m_parameters; } - bool is_node_in_cache(const std::string& name) const; - Output get_ng_node_from_cache(const std::string& name) const; + virtual Output get_ng_node_from_cache(const std::string& name) const; const std::string& get_name() const { return m_model->get_graph().name(); } OutputVector make_ng_nodes(const Node& onnx_node) const; const GraphCache& get_graph_cache() const; @@ -60,6 +61,11 @@ namespace ngraph const OutputVector& ng_node_vector) const; protected: + virtual void decode_to_framework_nodes(); + void convert_to_ngraph_nodes(); + void remove_dangling_parameters(); + std::shared_ptr create_function(); + ParameterVector m_parameters; std::unique_ptr m_model; std::unique_ptr m_cache; @@ -82,9 +88,11 @@ namespace ngraph /// \param[in] parent_graph The reference to the parent graph. Subgraph(std::unique_ptr&& model, const Graph& parent_graph); - /// \brief Return outputs which are on the edge the subgraph and the parent graph. + /// \brief Return nodes which are on the edge the subgraph and the parent graph. /// \return Vector of edge nodes from parent scope. - const std::vector> get_outputs_from_parent() const; + const std::vector> get_inputs_from_parent() const; + + std::shared_ptr convert() override; Subgraph() = delete; @@ -94,8 +102,17 @@ namespace ngraph Subgraph& operator=(const Subgraph&) = delete; Subgraph& operator=(Subgraph&&) = default; + Output get_ng_node_from_cache(const std::string& name) const override; + void infer_inputs_from_parent(); + private: - std::vector> m_outputs_from_parent; + void decode_to_framework_nodes() override; + void find_inputs_from_parent(); + + const GraphCache* m_parent_graph_cache; + std::vector m_inputs_from_parent; + std::unordered_map, std::string> + m_parameter_to_parent_node_map; }; inline std::ostream& operator<<(std::ostream& outs, const Graph& graph) diff --git a/ngraph/frontend/onnx_import/src/core/graph_cache.cpp b/ngraph/frontend/onnx_import/src/core/graph_cache.cpp index 9a0e0b59bbc42e..69593c062a3e69 100644 --- a/ngraph/frontend/onnx_import/src/core/graph_cache.cpp +++ b/ngraph/frontend/onnx_import/src/core/graph_cache.cpp @@ -39,55 +39,5 @@ namespace ngraph { return (m_graph_cache_map.count(name) > 0); } - - NodeScope GraphCache::node_scope(const std::string& name) const - { - return contains(name) ? NodeScope::ParentGraph : NodeScope::Lack; - } - - SubgraphCache::SubgraphCache(const GraphCache& parent_graph_cache) - : m_parent_graph_cache{&parent_graph_cache} - { - if (m_parent_graph_cache == nullptr) - { - throw ngraph_error("Parent graph cache is not initialized"); - } - } - - Output SubgraphCache::get_node(const std::string& name) const - { - // present in subgraph scope - if (GraphCache::contains(name)) - { - return GraphCache::get_node(name); - } - else // present in parent graph scope - { - return m_parent_graph_cache->get_node(name); - } - } - - bool SubgraphCache::contains(const std::string& name) const - { - // the node is in subgraph or in parent graph scope - return GraphCache::contains(name) || m_parent_graph_cache->contains(name); - } - - NodeScope SubgraphCache::node_scope(const std::string& name) const - { - if (GraphCache::contains(name)) - { - return NodeScope::SubGraph; - } - else if (m_parent_graph_cache->contains(name)) - { - return NodeScope::ParentGraph; - } - else - { - return NodeScope::Lack; - } - } - } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/core/graph_cache.hpp b/ngraph/frontend/onnx_import/src/core/graph_cache.hpp index a59af9b4a9f146..556811a91df326 100644 --- a/ngraph/frontend/onnx_import/src/core/graph_cache.hpp +++ b/ngraph/frontend/onnx_import/src/core/graph_cache.hpp @@ -14,17 +14,6 @@ namespace ngraph { namespace onnx_import { - /// \brief Enum which determines scope (visibility) of nodes in GraphCache. - enum class NodeScope - { - // in parent graph scope - ParentGraph = 1, - // in subgraph scope - SubGraph, - // not available at all - Lack - }; - /// \brief GraphCache stores and provides access to ONNX graph initializers. class GraphCache { @@ -58,58 +47,10 @@ namespace ngraph /// \return true if the node named `name` exist in the cache, false otherwise. virtual bool contains(const std::string& name) const; - /// \brief Return NodeScope enum which determines scope of the node. - /// \note If the method is called on GraphCache the ParentGraph enum - /// value is retunred always. - /// - /// \param[in] name The name of the node. - /// - /// \return SubGraph if node belongs to SubgraphCache, ParentGraph if - /// is avalible in parent_graph_cache, otherwise Lack - virtual NodeScope node_scope(const std::string& name) const; - virtual ~GraphCache() = default; private: std::map> m_graph_cache_map; }; - - class SubgraphCache : public GraphCache - { - public: - /// \brief Constructs a SubgraphCache class object. - /// - /// \param[in] parent_graph_cache The reference to the parent graph. - SubgraphCache(const GraphCache& parent_graph_cache); - - /// \brief Get the node from the cache (subgraph or parent graph) - /// - /// \note If the node is not found the ngraph_error exception is thrown. - /// - /// \param[in] name The name of the node. - /// - /// \return The node named `name` from subgraph (as present) or from parent graph. - Output get_node(const std::string& name) const override; - - /// \brief Return true if the node named `name` exist in the cache. - /// - /// \param[in] name The name of the node. - /// - /// \return true if the node named `name` exist in the cache - /// (subgraph or parent graph), false otherwise. - bool contains(const std::string& name) const override; - - /// \brief Return NodeScope enum which determines scope of the node. - /// - /// \param[in] name The name of the node. - /// - /// \return SubGraph if the node belongs to SubgraphCache, ParentGraph if - /// is avalible in parent_graph_cache, otherwise Lack - NodeScope node_scope(const std::string& name) const override; - - private: - const GraphCache* m_parent_graph_cache; - }; - } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/core/model.cpp b/ngraph/frontend/onnx_import/src/core/model.cpp index 452aea7b4775e4..2ddd3edac02e7a 100644 --- a/ngraph/frontend/onnx_import/src/core/model.cpp +++ b/ngraph/frontend/onnx_import/src/core/model.cpp @@ -6,6 +6,7 @@ #include "core/model.hpp" #include "ngraph/log.hpp" +#include "onnx_import/onnx_framework_node.hpp" #include "ops_bridge.hpp" namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/core/node.cpp b/ngraph/frontend/onnx_import/src/core/node.cpp index 1361e802bbff24..b6f2797263b384 100644 --- a/ngraph/frontend/onnx_import/src/core/node.cpp +++ b/ngraph/frontend/onnx_import/src/core/node.cpp @@ -26,6 +26,29 @@ namespace ngraph , m_graph{&graph} , m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())} , m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())} + { + const auto it = + std::find_if(std::begin(m_attributes), + std::end(m_attributes), + [&](const Attribute& attribute) { return attribute.is_graph(); }); + m_has_subgraph = it != std::end(m_attributes); + if (m_has_subgraph) + { + m_subgraph = std::make_shared(it->get_subgraph(*m_graph)); + } + } + + Impl(const ONNX_NAMESPACE::NodeProto& node_proto, + const Graph& graph, + std::shared_ptr subgraph) + : m_node_proto{&node_proto} + , m_name{node_proto.has_name() ? node_proto.name() : ""} + , m_domain{get_node_domain(node_proto)} + , m_graph{&graph} + , m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())} + , m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())} + , m_has_subgraph(subgraph != nullptr) + , m_subgraph(subgraph) { } @@ -44,9 +67,8 @@ namespace ngraph bool has_attribute(const std::string& name) const; - Subgraph get_subgraph_from_attribute( - const std::string& name, - const std::map& carried_dependencies_map) const; + bool has_subgraph() const; + std::shared_ptr get_subgraph() const; template T get_attribute_value(const std::string& name, T default_value) const; @@ -58,6 +80,8 @@ namespace ngraph const Graph& graph() const; private: + Subgraph get_subgraph_from_attribute(const std::string& name) const; + const ONNX_NAMESPACE::NodeProto* m_node_proto; std::string m_name; std::string m_domain; @@ -65,6 +89,9 @@ namespace ngraph std::vector m_attributes; std::vector> m_output_names; mutable std::string m_description; + + bool m_has_subgraph; + std::shared_ptr m_subgraph; }; const ONNX_NAMESPACE::NodeProto& Node::Impl::node_proto() const { return *m_node_proto; } @@ -94,9 +121,7 @@ namespace ngraph return it != std::end(m_attributes); } - Subgraph Node::Impl::get_subgraph_from_attribute( - const std::string& name, - const std::map& carried_dependencies_map) const + Subgraph Node::Impl::get_subgraph_from_attribute(const std::string& name) const { auto it = std::find_if( std::begin(m_attributes), std::end(m_attributes), [&](const Attribute& attribute) { @@ -106,9 +131,13 @@ namespace ngraph { throw error::node::UnknownAttribute{this->name(), name}; } - return it->get_subgraph(graph(), carried_dependencies_map); + return it->get_subgraph(*m_graph); } + bool Node::Impl::has_subgraph() const { return m_has_subgraph; } + + std::shared_ptr Node::Impl::get_subgraph() const { return m_subgraph; } + template T Node::Impl::get_attribute_value(const std::string& name, T default_value) const { @@ -140,8 +169,7 @@ namespace ngraph template <> Subgraph Node::Impl::get_attribute_value(const std::string& name) const { - const std::map empty_map; - return get_subgraph_from_attribute(name, empty_map); + return get_subgraph_from_attribute(name); } OutputVector Node::Impl::get_ng_nodes(const Node& node) const @@ -196,7 +224,9 @@ namespace ngraph } Node::Node(const Node& other) - : m_pimpl{new Impl{other.m_pimpl->node_proto(), other.m_pimpl->graph()}, + : m_pimpl{new Impl{other.m_pimpl->node_proto(), + other.m_pimpl->graph(), + other.get_subgraph()}, [](Impl* impl) { delete impl; }} { } @@ -219,12 +249,9 @@ namespace ngraph return m_pimpl->has_attribute(name); } - Subgraph Node::get_subgraph_from_attribute( - const std::string& name, - const std::map& carried_dependencies_map) const - { - return m_pimpl->get_subgraph_from_attribute(name, carried_dependencies_map); - } + bool Node::has_subgraph() const { return m_pimpl->has_subgraph(); } + + std::shared_ptr Node::get_subgraph() const { return m_pimpl->get_subgraph(); } std::vector Node::get_attribute_names() const { @@ -462,7 +489,6 @@ namespace ngraph { return m_pimpl->template get_attribute_value>(name); } - } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/core/null_node.hpp b/ngraph/frontend/onnx_import/src/core/null_node.hpp index c02a06ecfd2706..dd75770488c435 100644 --- a/ngraph/frontend/onnx_import/src/core/null_node.hpp +++ b/ngraph/frontend/onnx_import/src/core/null_node.hpp @@ -36,7 +36,10 @@ namespace ngraph public: static constexpr NodeTypeInfo type_info{"NullNode", 0}; const NodeTypeInfo& get_type_info() const override { return type_info; } - NullNode() = default; + NullNode() + : Node(1) + { + } virtual std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; diff --git a/ngraph/frontend/onnx_import/src/core/value_info.hpp b/ngraph/frontend/onnx_import/src/core/value_info.hpp index 67f2c5f7e2b779..76b3357c6ab3bb 100644 --- a/ngraph/frontend/onnx_import/src/core/value_info.hpp +++ b/ngraph/frontend/onnx_import/src/core/value_info.hpp @@ -19,20 +19,6 @@ namespace ngraph { namespace onnx_import { - namespace error - { - namespace value_info - { - struct unspecified_element_type : ngraph_error - { - unspecified_element_type() - : ngraph_error{"value info has no element type specified"} - { - } - }; - } // namespace value_info - } // namespace error - class ValueInfo { public: @@ -65,12 +51,12 @@ namespace ngraph const PartialShape& get_shape() const { return m_partial_shape; } const element::Type& get_element_type() const { - if (!m_value_info_proto->type().tensor_type().has_elem_type()) + if (m_value_info_proto->type().tensor_type().has_elem_type()) { - throw error::value_info::unspecified_element_type{}; + return common::get_ngraph_element_type( + m_value_info_proto->type().tensor_type().elem_type()); } - return common::get_ngraph_element_type( - m_value_info_proto->type().tensor_type().elem_type()); + return ngraph::element::dynamic; } std::shared_ptr diff --git a/ngraph/frontend/onnx_import/src/onnx_framework_node.cpp b/ngraph/frontend/onnx_import/src/onnx_framework_node.cpp new file mode 100644 index 00000000000000..bf52a1a2c0b8a0 --- /dev/null +++ b/ngraph/frontend/onnx_import/src/onnx_framework_node.cpp @@ -0,0 +1,34 @@ +//***************************************************************************** +// Copyright 2017-2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include + +namespace ngraph +{ + namespace frontend + { + NGRAPH_RTTI_DEFINITION(ONNXFrameworkNode, "ONNXFrameworkNode", 1); + + std::shared_ptr + ONNXFrameworkNode::clone_with_new_inputs(const OutputVector& inputs) const + { + return std::make_shared(m_node, inputs); + } + + NGRAPH_RTTI_DEFINITION(ONNXSubgraphFrameworkNode, "ONNXSubgraphFrameworkNode", 1); + + } // namespace frontend +} // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/op/loop.cpp b/ngraph/frontend/onnx_import/src/op/loop.cpp index 23ded7464e2d3b..dbe4f68d8c983a 100644 --- a/ngraph/frontend/onnx_import/src/op/loop.cpp +++ b/ngraph/frontend/onnx_import/src/op/loop.cpp @@ -77,10 +77,18 @@ namespace ngraph loop_carried_dependencies[i].get_node()->get_friendly_name(); } - const Subgraph& body_graph{ - node.get_subgraph_from_attribute("body", loop_carried_dependencies_map)}; - auto body_outputs = body_graph.get_ng_outputs(); - const auto& body_inputs = body_graph.get_ng_parameters(); + auto body_graph = node.get_subgraph(); + auto body_outputs = body_graph->get_ng_outputs(); + const auto& body_inputs = body_graph->get_ng_parameters(); + + // Infer loop body inputs' element type based on carried dependencies + for (size_t i = 0; i < loop_carried_dependencies.size(); i++) + { + body_inputs[i + 2]->set_element_type( + loop_carried_dependencies[i].get_element_type()); + body_inputs[i + 2]->set_partial_shape( + loop_carried_dependencies[i].get_partial_shape()); + } // optional inputs Output trip_count; @@ -190,22 +198,22 @@ namespace ngraph final_values.push_back(loop->get_iter_value(*body_outputs_it++, -1)); } - const auto& outputs_from_parent = body_graph.get_outputs_from_parent(); + const auto& inputs_from_parent = body_graph->get_inputs_from_parent(); CHECK_VALID_NODE( node, static_cast(std::distance(body_inputs_it, body_inputs.end())) == - outputs_from_parent.size(), + inputs_from_parent.size(), "Expected number of invariant parameters is" - " not equal number of provided outputs from parent scope"); + " not equal number of provided inputs from parent scope"); // Set-up parameters from parent graph which are not changed during Loop's // iterations - for (auto out_from_parent_it = outputs_from_parent.begin(); + for (auto in_from_parent_it = inputs_from_parent.begin(); body_inputs_it != body_inputs.end() && - out_from_parent_it != outputs_from_parent.end(); - ++body_inputs_it, ++out_from_parent_it) + in_from_parent_it != inputs_from_parent.end(); + ++body_inputs_it, ++in_from_parent_it) { - loop->set_invariant_input(*body_inputs_it, *out_from_parent_it); + loop->set_invariant_input(*body_inputs_it, *in_from_parent_it); } // Set-up scan outputs diff --git a/ngraph/frontend/onnx_import/src/utils/onnx_internal.cpp b/ngraph/frontend/onnx_import/src/utils/onnx_internal.cpp index 74bb4a72d5c19c..8e60171a198c91 100644 --- a/ngraph/frontend/onnx_import/src/utils/onnx_internal.cpp +++ b/ngraph/frontend/onnx_import/src/utils/onnx_internal.cpp @@ -6,7 +6,9 @@ #include "core/graph.hpp" #include "core/model.hpp" +#include "core/null_node.hpp" #include "core/transform.hpp" +#include "onnx_import/onnx_framework_node.hpp" #include "onnx_import/utils/onnx_internal.hpp" namespace ngraph @@ -15,21 +17,81 @@ namespace ngraph { namespace detail { - std::shared_ptr - convert_to_ng_function(const ONNX_NAMESPACE::ModelProto& model_proto) + void remove_dangling_parameters(std::shared_ptr& function) { - auto p_model_proto = common::make_unique(model_proto); - auto model = common::make_unique(std::move(p_model_proto)); + const auto parameters = function->get_parameters(); + for (auto parameter : parameters) + { + const auto parameter_users = parameter->get_users(); + // if a Parameter is connected to a ONNXFrameworkNode that was not converted + // during convert_function it means, this Parameter is dangling and we can + // remove it from function + const bool is_dangling_parameter = std::all_of( + parameter_users.begin(), + parameter_users.end(), + [](const std::shared_ptr& node) -> bool { + return std::dynamic_pointer_cast(node) != + nullptr; + }); + if (is_dangling_parameter) + { + function->remove_parameter(parameter); + } + } + } - Graph graph{std::move(model)}; - auto function = std::make_shared( - graph.get_ng_outputs(), graph.get_ng_parameters(), graph.get_name()); - for (std::size_t i{0}; i < function->get_output_size(); ++i) + void remove_dangling_results(std::shared_ptr& function) + { + const auto results = function->get_results(); + for (auto result : results) { - function->get_output_op(i)->set_friendly_name( - graph.get_outputs().at(i).get_name()); + // we can remove Result from function if after function conversion, + // Result is connected to NullNode only + const auto result_inputs = result->input_values(); + const bool is_dangling_result = + std::all_of(result_inputs.begin(), + result_inputs.end(), + [](const Output& node) -> bool { + return ngraph::op::is_null(node); + }); + if (is_dangling_result) + { + function->remove_result(result); + } } - return function; + } + + void convert_decoded_function(std::shared_ptr function) + { + for (const auto& node : function->get_ordered_ops()) + { + if (auto raw_node = + std::dynamic_pointer_cast(node)) + { + if (auto subgraph_node = + std::dynamic_pointer_cast( + node)) + { + subgraph_node->infer_inputs_from_parent(); + convert_decoded_function(subgraph_node->get_subgraph_body()); + } + const auto& onnx_node = raw_node->get_onnx_node(); + OutputVector ng_nodes{onnx_node.get_ng_nodes()}; + if (ng_nodes.size() > raw_node->get_output_size()) + { + ng_nodes.resize(raw_node->get_output_size()); + } + replace_node(raw_node, ng_nodes); + } + else + { + // Have to revalidate node because new intpus can affect shape/type + // propagation for already translated nodes + node->revalidate_and_infer_types(); + } + } + remove_dangling_parameters(function); + remove_dangling_results(function); } std::shared_ptr import_onnx_model(ONNX_NAMESPACE::ModelProto& model_proto, @@ -39,7 +101,10 @@ namespace ngraph transform::fixup_legacy_operators(model_proto); transform::update_external_data_paths(model_proto, model_path); - return detail::convert_to_ng_function(model_proto); + auto p_model_proto = common::make_unique(model_proto); + auto model = common::make_unique(std::move(p_model_proto)); + Graph graph{std::move(model)}; + return graph.convert(); } } // namespace detail } // namespace onnx_import diff --git a/ngraph/python/tests/test_onnx/test_ops_unary.py b/ngraph/python/tests/test_onnx/test_ops_unary.py index 01c9eeb9f55888..22d6b54f539c29 100644 --- a/ngraph/python/tests/test_onnx/test_ops_unary.py +++ b/ngraph/python/tests/test_onnx/test_ops_unary.py @@ -390,8 +390,7 @@ def test_cast_errors(): for name, value in zip(node.input, [input_data]) ] output_tensors = [ - make_tensor_value_info(name, onnx.TensorProto.FLOAT16, value.shape) - for name, value in zip(node.output, ()) + make_tensor_value_info(node.output[0], onnx.TensorProto.FLOAT16, input_data.shape) ] # type: ignore graph = make_graph([node], "compute_graph", input_tensors, output_tensors) @@ -406,8 +405,7 @@ def test_cast_errors(): for name, value in zip(node.input, [input_data]) ] output_tensors = [ - make_tensor_value_info(name, onnx.TensorProto.INT32, value.shape) - for name, value in zip(node.output, ()) + make_tensor_value_info(node.output[0], onnx.TensorProto.INT32, input_data.shape) ] # type: ignore graph = make_graph([node], "compute_graph", input_tensors, output_tensors) @@ -422,8 +420,7 @@ def test_cast_errors(): for name, value in zip(node.input, [input_data]) ] output_tensors = [ - make_tensor_value_info(name, onnx.TensorProto.INT32, value.shape) - for name, value in zip(node.output, ()) + make_tensor_value_info(node.output[0], onnx.TensorProto.INT32, input_data.shape) ] # type: ignore graph = make_graph([node], "compute_graph", input_tensors, output_tensors) @@ -438,8 +435,7 @@ def test_cast_errors(): for name, value in zip(node.input, [input_data]) ] output_tensors = [ - make_tensor_value_info(name, onnx.TensorProto.COMPLEX128, value.shape) - for name, value in zip(node.output, ()) + make_tensor_value_info(node.output[0], onnx.TensorProto.COMPLEX128, input_data.shape) ] # type: ignore graph = make_graph([node], "compute_graph", input_tensors, output_tensors) diff --git a/ngraph/test/models/onnx/constant_fill_shape_attribute.prototxt b/ngraph/test/models/onnx/constant_fill_shape_attribute.prototxt index 806f01ffd89ddf..cdbbf99419a241 100644 --- a/ngraph/test/models/onnx/constant_fill_shape_attribute.prototxt +++ b/ngraph/test/models/onnx/constant_fill_shape_attribute.prototxt @@ -2,7 +2,6 @@ ir_version: 7 producer_name: "backend-test" graph { node { - input: "target_shape" output: "output" op_type: "ConstantFill" attribute {