Merge pull request #22 from mvafin/mvafin/pt_fe/unconv_loop
Fix conversion of loops as PtFrameworkNode
slyalin authored Oct 26, 2022
2 parents ba3bde8 + 0f763df commit 98aa401
Showing 2 changed files with 73 additions and 31 deletions.
21 changes: 11 additions & 10 deletions src/frontends/pytorch/src/pt_framework_node.hpp
@@ -1,4 +1,5 @@
#include <openvino/op/util/framework_node.hpp>

#include "utils.hpp"

#pragma once
@@ -86,12 +87,12 @@ class PtFrameworkNode : public ov::op::util::FrameworkNode {

bool visit_attributes(AttributeVisitor& visitor) override {
bool parent_visit_result = FrameworkNode::visit_attributes(visitor);
// TODO: serialize bodies and descriptors
/*for (size_t i = 0; i < m_bodies.size(); ++i) {
//visitor.on_attribute("body", m_bodies[i]);
//visitor.on_attribute("input_descriptions", m_input_descriptions[i]);
//visitor.on_attribute("output_descriptions", m_output_descriptions[i]);
}*/
// TODO: serialize bodies and descriptors correctly. Only the 1st body's information can be serialized.
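// e.g. with two bodies the loop below registers attributes "body0" and "body1"; the per-body
// input/output descriptions are still skipped.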
for (size_t i = 0; i < m_bodies.size(); ++i) {
visitor.on_attribute("body" + std::to_string(i), m_bodies[i]);
//visitor.on_attribute("input_descriptions" + std::to_string(i), m_input_descriptions[i]);
// visitor.on_attribute("output_descriptions", m_output_descriptions[i]);
}
return parent_visit_result;
}

@@ -119,16 +120,16 @@ class PtFrameworkNode : public ov::op::util::FrameworkNode {
for (const auto& output_description : m_output_descriptions[i]) {
auto index = output_description->m_output_index;

const auto& body_value = m_bodies[i]->get_results().at(output_description->m_body_value_index)->input_value(0).get_tensor();
const auto& body_value =
m_bodies[i]->get_results().at(output_description->m_body_value_index)->input_value(0).get_tensor();

if (auto body_output_description =
ov::as_type_ptr<op::v0::TensorIterator::BodyOutputDescription>(output_description)) {
const ov::PartialShape& ps = body_value.get_partial_shape();
auto et = body_value.get_element_type();
if(et == element::custom) {
if (et == element::custom) {
output(index).get_tensor().set_custom_element_type(body_value.get_custom_element_type());
}
else {
} else {
set_output_type(index, et, ps);
}
}
83 changes: 62 additions & 21 deletions src/frontends/pytorch/src/utils.cpp
@@ -85,8 +85,9 @@ Output<Node> reshape_kernel_for_group(const NodeContext& context,

OutputVector make_framework_node(NodeContext* context) {
auto schema = context->get_schema();
// TODO: properly process schema to get the actual position of a mutable input
// Hack. Can indicate mutable inputs, but can it be reliable?
if (schema.find('!') != std::string::npos) {
// Hack. Can indicate mutable inputs, but can it be reliable?
// We create an additional output for such nodes. It contains a new tensor that represents the input that was changed.
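// For illustration (approximate schema, not taken from this repo): an in-place op such as
// "aten::add_(Tensor(a!) self, Tensor other, Scalar alpha) -> Tensor(a!)" marks mutable
// arguments with '!', which is what the find('!') test above keys on.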
auto fw_node =
std::make_shared<PtFrameworkNode>(context->get_decoder(), context->inputs(), context->num_of_outputs() + 1);
@@ -109,9 +110,11 @@ OutputVector make_framework_node(NodeContext* context) {
fw_node->set_friendly_name(context->get_op_type());

std::map<size_t, ParameterVector> inputs_map;
std::map<size_t, ResultVector> extra_outputs_map;
std::set<size_t> input_idxs; // initial inputs
// We need to remember the initial inputs to be able to find the extra body inputs that were created to propagate
// the external context
int num_body_outs = 0;
for (size_t i = 0; i < context->get_decoder()->get_subgraph_size(); ++i) {
auto subgraph_decoder = context->get_decoder()->get_subgraph_decoder(i);
auto inputs = subgraph_decoder->inputs();
@@ -123,6 +126,23 @@
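// Body Parameters are named with the id of the outer-scope tensor they refer to (see the
// Parameter creation in convert_pytorch_model below), so e.g. a Parameter named "42"
// (hypothetical id) is collected under inputs_map[42].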
size_t input_idx = (size_t)std::stoll(name);
inputs_map[input_idx].push_back(param);
}
auto body_outputs = subgraph_decoder->outputs();
if (i == 0) {
num_body_outs = body_outputs.size();
} else {
FRONT_END_OP_CONVERSION_CHECK(
num_body_outs == body_outputs.size(),
"Number of outputs of this body is different from number of outputs of first body");
}
// Some bodies may have mutated inputs which we need to propagate to the external context
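// The same naming convention is used here: a mutated input appears as an extra body Result
// whose input tensor carries the id of the external tensor to update, so e.g. a Result named
// "7" (hypothetical id) replaces external tensor 7 once this node is executed.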
auto body_results = body->get_results();
for (size_t j = num_body_outs; j < body_results.size(); j++) {
auto name = body_results[j]->input(0).get_tensor().get_any_name();
size_t out_idx = (size_t)std::stoll(name);
FRONT_END_OP_CONVERSION_CHECK(extra_outputs_map.count(out_idx) == 0,
"More than one body output with the same tensor name.");
extra_outputs_map[out_idx].push_back(body_results[j]);
}
}
// Connect inputs with external context
for (const auto& input : inputs_map) {
@@ -136,10 +156,34 @@
}
}
}
// We do not connect body outputs. Depending from the kind of operation it can be done differently. Unconnected
// outputs from body should not invalidate the graph or result in deletion of some nodes, because outputs specified
// by pytorch would be connected in the outer scope.
return context->mark_node(fw_node)->outputs();
// The number of body outputs can be higher than the number of pt node outputs, e.g. in the case of a loop the
// first body output is the condition, so we have to skip such outputs
int num_skip_body_outputs =
num_body_outs > context->num_of_outputs() ? num_body_outs - context->num_of_outputs() : 0;
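// Hypothetical example: a loop whose body outputs are [cond, a, b] while the pt node outputs
// are [a, b] gives num_body_outs = 3 and num_of_outputs() = 2, hence num_skip_body_outputs = 1
// and set_output_size(2 - 3 + 1) = 0 below; both remaining outputs are then added from the body.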
// We need to reduce the number of node outputs, because some of them are produced by the body
fw_node->set_output_size(context->num_of_outputs() - num_body_outs + num_skip_body_outputs);
OutputVector res(context->mark_node(fw_node)->outputs());
if (fw_node->get_internal_subgraphs_size() > 0) {
auto first_body_results = fw_node->get_function(0)->get_results();
std::vector<ResultVector> outputs;
for (int i = num_skip_body_outputs; i < num_body_outs; i++) {
outputs.push_back({first_body_results[i]});
}
for (size_t i = 1; i < fw_node->get_internal_subgraphs_size(); i++) {
auto current_body_results = fw_node->get_function(i)->get_results();
// Use a distinct inner index and shift by num_skip_body_outputs: `outputs` holds only the
// non-skipped body outputs, so outputs[0] corresponds to body output num_skip_body_outputs.
for (int j = num_skip_body_outputs; j < num_body_outs; j++) {
outputs[j - num_skip_body_outputs].push_back(current_body_results[j]);
}
}
for (const auto& res_vec : outputs) {
res.push_back(fw_node->set_body_outputs(res_vec));
}
}
// Propagate extra outputs to external context
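// Assumption from usage: set_body_outputs takes one Result per body, creates a node output fed
// by them and returns it; add_tensor_to_context then rebinds the external tensor id to that
// output, so downstream consumers read the mutated value instead of the original one.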
for (const auto& output : extra_outputs_map) {
context->add_tensor_to_context(output.first, fw_node->set_body_outputs(output.second));
}
return res;
}

OutputVector convert_node(NodeContext* context) {
@@ -210,21 +254,16 @@ std::shared_ptr<ov::Model> convert_pytorch_model(std::shared_ptr<Decoder> pytorc

auto node_visitor = [&](std::shared_ptr<Decoder> node) {
// Explore all inputs of node. Node may refer to input value that hasn't been created in the current scope.
// But this value can be found in the outer scope, for this purpose we need to search node in
// external_tensor_map as well
// But this value can be found in the outer scope; for this purpose we create a new input for the model to
// be linked with the external scope at a higher level.

// std::cout << "Node visitor start: " << node->get_op_type() << ", schema: " << node->get_schema() <<
// std::endl;
auto raw_inputs = node->inputs();
for (size_t i = 0; i < raw_inputs.size(); ++i) {
auto input = node->input(i);
if (tensor_map.find(input) == tensor_map.end()) {
// input refers value in the outer scope, need to create a new Parameter in the current scope
// TODO: Connect outer scope and inner scope properly -- should be handled at the level of that
// operation that introduced this nest of scopes (e.g. loop or if)
// TODO: Eliminate duplication with the main code for Parameters creation
// TODO: There is no real search for values in outer scope because we don't need to link the usage
// and definition together at this point -- need to do that otherwise graph will fall apart
// The input refers to a value in the outer scope; we need to create a new Parameter in the current scope
// Linkage to the external scope will be performed at the level of the parent operation (if or loop)
// TODO: Eliminate duplication with the main code for Parameters creation
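// The Parameter is presumably named with the referenced tensor id (as done for model outputs
// below), which is what the inputs_map matching in make_framework_node relies on; e.g. a value
// with id 15 (hypothetical) becomes a Parameter named "15".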
PartialShape ps = node->get_input_shape(i);
auto type = simplified_type_interpret(node->get_input_type(i));
auto parameter = std::make_shared<opset8::Parameter>(element::custom, type, ps);
@@ -242,7 +281,7 @@ std::shared_ptr<ov::Model> convert_pytorch_model(std::shared_ptr<Decoder> pytorc
mutated_tensors.insert(mutated_t.begin(), mutated_t.end());

auto fw_outputs = node->outputs();
// ops with subgraphs has more outputs
// Ops with subgraphs or with mutated inputs may have more outputs after conversion compared to the pytorch ones
FRONT_END_OP_CONVERSION_CHECK(fw_outputs.size() <= converted_outputs.size(),
"Number of ",
node->get_op_type(),
@@ -272,9 +311,7 @@ std::shared_ptr<ov::Model> convert_pytorch_model(std::shared_ptr<Decoder> pytorc
for (size_t i = 0; i < pytorch_model->num_of_outputs(); ++i) {
size_t id = pytorch_model->output(i);
if (tensor_map.find(id) == tensor_map.end()) {
// Not found in this scope, searching in the outer scope
// TODO: do real search here, skipped for now

// Not found in this scope, adding Parameter to connect to external scope
auto parameter = std::make_shared<opset8::Parameter>(element::dynamic, PartialShape::dynamic());
parameter->get_output_tensor(0).add_names({std::to_string(id)});
parameters.push_back(parameter);
Expand All @@ -285,8 +322,12 @@ std::shared_ptr<ov::Model> convert_pytorch_model(std::shared_ptr<Decoder> pytorc
if (order.size() > 0 && !std::is_sorted(order.begin(), order.end())) {
throw "Output strides have wrong order.";
}
// TODO: remove when all nodes has ids
ov_output.add_names({std::to_string(id)});
FRONT_END_GENERAL_CHECK(ov_output.get_names().size() > 0,
"Tensor doesn't have name, while it should have name: ",
id);
FRONT_END_GENERAL_CHECK(ov_output.get_any_name().find(std::to_string(id)) != std::string::npos,
"any_name of tensor doesn't contain actual name: ",
id);
auto result = std::make_shared<opset8::Result>(ov_output);
results.push_back(result);
}