Skip to content

Commit

Permalink
Integration ONNX Editor with FE API (openvinotoolkit#6773)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mateusz Bencer authored and andrei-cv committed Aug 30, 2021
1 parent 84860c7 commit 0d25a5a
Show file tree
Hide file tree
Showing 14 changed files with 1,207 additions and 94 deletions.
2 changes: 1 addition & 1 deletion ngraph/frontend/frontend_manager/src/frontend_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ void InputModel::set_partial_shape(Place::Ptr place, const ngraph::PartialShape&

ngraph::PartialShape InputModel::get_partial_shape(Place::Ptr place) const
{
FRONT_END_NOT_IMPLEMENTED(set_partial_shape);
FRONT_END_NOT_IMPLEMENTED(get_partial_shape);
}

void InputModel::set_element_type(Place::Ptr place, const ngraph::element::Type&)
Expand Down
112 changes: 100 additions & 12 deletions ngraph/frontend/onnx/frontend/src/input_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,143 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "input_model.hpp"
#include <frontend_manager/frontend_exceptions.hpp>
#include <input_model.hpp>
#include <place.hpp>
#include "place.hpp"

using namespace ngraph;
using namespace ngraph::frontend;

InputModelONNX::InputModelONNX(const std::string& path)
: m_editor(path)
: m_editor{std::make_shared<onnx_editor::ONNXModelEditor>(path)}
{
}

std::vector<Place::Ptr> InputModelONNX::get_inputs() const
{
auto inputs = m_editor.model_inputs();
std::vector<Place::Ptr> ret;
ret.reserve(inputs.size());
const auto& inputs = m_editor->model_inputs();
std::vector<Place::Ptr> in_places;
in_places.reserve(inputs.size());
for (const auto& input : inputs)
{
ret.push_back(std::make_shared<PlaceTensorONNX>(input, m_editor));
in_places.push_back(std::make_shared<PlaceTensorONNX>(input, m_editor));
}
return ret;
return in_places;
}

std::vector<Place::Ptr> InputModelONNX::get_outputs() const
{
const auto& outputs = m_editor->model_outputs();
std::vector<Place::Ptr> out_places;
out_places.reserve(outputs.size());
for (const auto& output : outputs)
{
out_places.push_back(std::make_shared<PlaceTensorONNX>(output, m_editor));
}
return out_places;
}

Place::Ptr InputModelONNX::get_place_by_tensor_name(const std::string& tensor_name) const
{
NGRAPH_CHECK(m_editor->is_correct_tensor_name(tensor_name),
"The tensor with name: " + tensor_name + " does not exist in the graph");
return std::make_shared<PlaceTensorONNX>(tensor_name, m_editor);
}

Place::Ptr
InputModelONNX::get_place_by_operation_name_and_input_port(const std::string& operation_name,
int input_port_index)
{
const auto edge =
m_editor->find_input_edge(onnx_editor::EditorNode(operation_name), input_port_index);
return std::make_shared<PlaceInputEdgeONNX>(edge, m_editor);
}

void InputModelONNX::set_partial_shape(Place::Ptr place, const ngraph::PartialShape& shape)
{
std::map<std::string, ngraph::PartialShape> m;
m[place->get_names()[0]] = shape;
m_editor.set_input_shapes(m);
m_editor->set_input_shapes(m);
}

ngraph::PartialShape InputModelONNX::get_partial_shape(Place::Ptr place) const
{
return m_editor->get_tensor_shape(place->get_names().at(0));
}

void InputModelONNX::set_element_type(Place::Ptr place, const ngraph::element::Type& type)
{
std::map<std::string, ngraph::element::Type_t> m;
m[place->get_names()[0]] = type;
m_editor.set_input_types(m);
m_editor->set_input_types(m);
}

std::shared_ptr<Function> InputModelONNX::decode()
{
return m_editor.decode();
return m_editor->decode();
}

std::shared_ptr<Function> InputModelONNX::convert()
{
return m_editor.get_function();
return m_editor->get_function();
}

// Editor features
void InputModelONNX::override_all_outputs(const std::vector<Place::Ptr>& outputs)
{
extract_subgraph({}, outputs);
NGRAPH_CHECK(m_editor->model_outputs().size() == outputs.size(),
"Unexpected number of outputs after override_all_outputs");
NGRAPH_CHECK(std::all_of(std::begin(outputs),
std::end(outputs),
[](const Place::Ptr& place) { return place->is_output(); }),
"Not all provided arguments of override_all_outputs are new outputs of the model");
}

void InputModelONNX::override_all_inputs(const std::vector<Place::Ptr>& inputs)
{
const auto outputs_before_extraction = m_editor->model_outputs();
extract_subgraph({inputs}, {});
NGRAPH_CHECK(std::equal(std::begin(outputs_before_extraction),
std::end(outputs_before_extraction),
std::begin(m_editor->model_outputs())),
"All outputs should be preserved after override_all_inputs. Provided inputs does "
"not satisfy all outputs");
NGRAPH_CHECK(m_editor->model_inputs().size() == inputs.size(),
"Unexpected number of inputs after override_all_inputs");
}

void InputModelONNX::extract_subgraph(const std::vector<Place::Ptr>& inputs,
const std::vector<Place::Ptr>& outputs)
{
std::vector<onnx_editor::InputEdge> onnx_inputs;
onnx_inputs.reserve(inputs.size());
for (const auto& input : inputs)
{
if (const auto input_port = std::dynamic_pointer_cast<PlaceInputEdgeONNX>(input))
{
onnx_inputs.push_back(input_port->get_input_edge());
}
else if (const auto tensor = std::dynamic_pointer_cast<PlaceTensorONNX>(input))
{
auto name = tensor->get_names()[0];
const auto consumers = m_editor->find_output_consumers(name);
std::transform(std::begin(consumers),
std::end(consumers),
std::back_inserter(onnx_inputs),
[](const onnx_editor::InputEdge& edge) { return edge; });
}
}

std::vector<onnx_editor::OutputEdge> onnx_outputs;
onnx_outputs.reserve(outputs.size());
for (const auto& output : outputs)
{
const auto output_port = output->get_producing_port();
const auto onnx_output_edge = std::dynamic_pointer_cast<PlaceOutputEdgeONNX>(output_port);
NGRAPH_CHECK(onnx_output_edge,
"Non-onnx output place was passed as extraction subgraph argument");
onnx_outputs.push_back(onnx_output_edge->get_output_edge());
}
m_editor->cut_graph_fragment(onnx_inputs, onnx_outputs);
}
12 changes: 11 additions & 1 deletion ngraph/frontend/onnx/frontend/src/input_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,25 @@ namespace ngraph
InputModelONNX(const std::string& path);

std::vector<Place::Ptr> get_inputs() const override;
std::vector<Place::Ptr> get_outputs() const override;
Place::Ptr get_place_by_tensor_name(const std::string& tensor_name) const override;
Place::Ptr get_place_by_operation_name_and_input_port(const std::string& operation_name,
int input_port_index) override;
void set_partial_shape(Place::Ptr place, const ngraph::PartialShape& shape) override;
ngraph::PartialShape get_partial_shape(Place::Ptr place) const override;
void set_element_type(Place::Ptr place, const ngraph::element::Type& type) override;

std::shared_ptr<Function> decode();
std::shared_ptr<Function> convert();

// Editor features
void override_all_outputs(const std::vector<Place::Ptr>& outputs) override;
void override_all_inputs(const std::vector<Place::Ptr>& inputs) override;
void extract_subgraph(const std::vector<Place::Ptr>& inputs,
const std::vector<Place::Ptr>& outputs) override;

private:
onnx_editor::ONNXModelEditor m_editor;
std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
};

} // namespace frontend
Expand Down
134 changes: 134 additions & 0 deletions ngraph/frontend/onnx/frontend/src/place.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "place.hpp"
#include <frontend_manager/frontend_exceptions.hpp>

using namespace ngraph;
using namespace ngraph::frontend;

PlaceInputEdgeONNX::PlaceInputEdgeONNX(const onnx_editor::InputEdge& edge,
std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
: m_edge{edge}
, m_editor{editor}
{
}

onnx_editor::InputEdge PlaceInputEdgeONNX::get_input_edge() const
{
return m_edge;
}

bool PlaceInputEdgeONNX::is_input() const
{
return m_editor->is_input(m_edge);
}

bool PlaceInputEdgeONNX::is_output() const
{
return false;
}

bool PlaceInputEdgeONNX::is_equal(Place::Ptr another) const
{
if (const auto in_edge = std::dynamic_pointer_cast<PlaceInputEdgeONNX>(another))
{
const auto& editor_edge = in_edge->get_input_edge();
return (editor_edge.m_node_idx == m_edge.m_node_idx) &&
(editor_edge.m_port_idx == m_edge.m_port_idx);
}
return false;
}

PlaceOutputEdgeONNX::PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge,
std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
: m_edge{edge}
, m_editor{editor}
{
}

onnx_editor::OutputEdge PlaceOutputEdgeONNX::get_output_edge() const
{
return m_edge;
}

bool PlaceOutputEdgeONNX::is_input() const
{
return false;
}

bool PlaceOutputEdgeONNX::is_output() const
{
return m_editor->is_output(m_edge);
}

bool PlaceOutputEdgeONNX::is_equal(Place::Ptr another) const
{
if (const auto out_edge = std::dynamic_pointer_cast<PlaceOutputEdgeONNX>(another))
{
const auto& editor_edge = out_edge->get_output_edge();
return (editor_edge.m_node_idx == m_edge.m_node_idx) &&
(editor_edge.m_port_idx == m_edge.m_port_idx);
}
return false;
}

PlaceTensorONNX::PlaceTensorONNX(const std::string& name,
std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
: m_name(name)
, m_editor(editor)
{
}

std::vector<std::string> PlaceTensorONNX::get_names() const
{
return {m_name};
}

Place::Ptr PlaceTensorONNX::get_producing_port() const
{
return std::make_shared<PlaceOutputEdgeONNX>(m_editor->find_output_edge(m_name), m_editor);
}

std::vector<Place::Ptr> PlaceTensorONNX::get_consuming_ports() const
{
std::vector<Place::Ptr> ret;
auto edges = m_editor->find_output_consumers(m_name);
std::transform(edges.begin(),
edges.end(),
std::back_inserter(ret),
[this](const onnx_editor::InputEdge& edge) {
return std::make_shared<PlaceInputEdgeONNX>(edge, this->m_editor);
});
return ret;
}

Place::Ptr PlaceTensorONNX::get_input_port(int input_port_index) const
{
return std::make_shared<PlaceInputEdgeONNX>(
m_editor->find_input_edge(onnx_editor::EditorOutput(m_name),
onnx_editor::EditorInput(input_port_index)),
m_editor);
}

bool PlaceTensorONNX::is_input() const
{
const auto inputs = m_editor->model_inputs();
return std::find(std::begin(inputs), std::end(inputs), m_name) != std::end(inputs);
}

bool PlaceTensorONNX::is_output() const
{
const auto outputs = m_editor->model_outputs();
return std::find(std::begin(outputs), std::end(outputs), m_name) != std::end(outputs);
}

bool PlaceTensorONNX::is_equal(Place::Ptr another) const
{
if (const auto tensor = std::dynamic_pointer_cast<PlaceTensorONNX>(another))
{
return m_name == tensor->get_names().at(0);
}
return false;
}
Loading

0 comments on commit 0d25a5a

Please sign in to comment.