Skip to content

Commit

Permalink
[PT FE] Support moving TupleConstruct inside If body (openvinotoolkit…
Browse files Browse the repository at this point in the history
…#20081)

* Support moving TupleConstruct inside If body

* Fix win build

---------

Co-authored-by: Alina Kladieva <[email protected]>
  • Loading branch information
mvafin and akladiev authored Sep 28, 2023
1 parent 64cc3a9 commit f38b5f4
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/frontends/pytorch/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::pass::ConstantFolding>();
manager.register_pass<ov::pass::PushConstantToSubgraph>();
manager.register_pass<ov::pass::UnrollIf>();
manager.register_pass<ov::frontend::pytorch::pass::TupleUnpackInBodyReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenCatToConcat>();
manager.register_pass<ov::frontend::pytorch::pass::AppendListUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenStackListConstructReplacer>();
Expand Down
4 changes: 2 additions & 2 deletions src/frontends/pytorch/src/pt_framework_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ class PtFrameworkNode : public ov::op::util::FrameworkNode {
return m_decoder->get_op_type();
}

TorchDecoder* get_decoder() const {
return m_decoder.get();
std::shared_ptr<TorchDecoder> get_decoder() const {
return m_decoder;
}

private:
Expand Down
142 changes: 142 additions & 0 deletions src/frontends/pytorch/src/transforms/tuple_unpack_replacer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,20 @@

#include "tuple_unpack_replacer.hpp"

#include "openvino/op/if.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {

using namespace ov::op;

PrimTupleUnpackReplacer::PrimTupleUnpackReplacer() {
auto tuple_unpack = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();

Expand Down Expand Up @@ -41,6 +45,144 @@ PrimTupleUnpackReplacer::PrimTupleUnpackReplacer() {
this->register_matcher(m, callback);
};

bool TupleUnpackInBodyReplacer::run_on_model(const std::shared_ptr<Model>& model) {
bool result = false;
for (auto op : model->get_ordered_ops()) {
const auto if_op = as_type_ptr<v8::If>(op);
if (if_op) {
for (size_t i = 1; i < if_op->get_input_size(); i++) {
auto input = if_op->input_value(i);
auto tuple_construct = std::dynamic_pointer_cast<ov::frontend::pytorch::PtFrameworkNode>(
cast_fw_node(input.get_node_shared_ptr(), "prim::TupleConstruct"));
if (!tuple_construct) {
continue;
}
int then_body_idx = -1;
int else_body_idx = -1;
auto then_descs = if_op->get_input_descriptions(v8::If::THEN_BODY_INDEX);
auto else_descs = if_op->get_input_descriptions(v8::If::ELSE_BODY_INDEX);
for (auto inp_desc : then_descs) {
if (inp_desc->m_input_index == i) {
if (then_body_idx != -1) {
add_exception_to_fw_node(
tuple_construct,
"Unexpected: TupleConstruct output is used in body more then once.");
} else {
then_body_idx = static_cast<int>(inp_desc->m_body_parameter_index);
}
}
}
for (auto inp_desc : else_descs) {
if (inp_desc->m_input_index == i) {
if (else_body_idx != -1) {
add_exception_to_fw_node(
tuple_construct,
"Unexpected: TupleConstruct output is used in body more then once.");
} else {
else_body_idx = static_cast<int>(inp_desc->m_body_parameter_index);
}
}
}
auto new_if = std::make_shared<v8::If>(if_op->input_value(0));
auto then_body = if_op->get_function(v8::If::THEN_BODY_INDEX);
auto else_body = if_op->get_function(v8::If::ELSE_BODY_INDEX);
ov::ParameterVector new_then_params;
ov::ParameterVector new_else_params;
if (then_body_idx != -1) {
auto then_param = then_body->get_parameters().at(then_body_idx);
ov::OutputVector new_tc_inputs;
for (size_t i = 0; i < tuple_construct->get_input_size(); i++) {
auto new_param = std::make_shared<v0::Parameter>(element::dynamic, PartialShape::dynamic());
new_then_params.push_back(new_param);
new_tc_inputs.push_back(new_param);
}
auto new_tc =
std::make_shared<ov::frontend::pytorch::PtFrameworkNode>(tuple_construct->get_decoder(),
new_tc_inputs,
1);
then_body->add_parameters(new_then_params);
then_body->remove_parameter(then_param);
then_param->output(0).replace(new_tc->output(0));
}
if (else_body_idx != -1) {
auto else_param = else_body->get_parameters().at(else_body_idx);
ov::OutputVector new_tc_inputs;
for (size_t i = 0; i < tuple_construct->get_input_size(); i++) {
auto new_param = std::make_shared<v0::Parameter>(element::dynamic, PartialShape::dynamic());
new_else_params.push_back(new_param);
new_tc_inputs.push_back(new_param);
}
auto new_tc =
std::make_shared<ov::frontend::pytorch::PtFrameworkNode>(tuple_construct->get_decoder(),
new_tc_inputs,
1);
else_body->add_parameters(new_else_params);
else_body->remove_parameter(else_param);
else_param->output(0).replace(new_tc->output(0));
}
new_if->set_function(v8::If::THEN_BODY_INDEX, then_body);
new_if->set_function(v8::If::ELSE_BODY_INDEX, else_body);
new_if->set_output_size(if_op->get_output_size());
new_if->set_output_descriptions(v8::If::THEN_BODY_INDEX,
if_op->get_output_descriptions(v8::If::THEN_BODY_INDEX));
new_if->set_output_descriptions(v8::If::ELSE_BODY_INDEX,
if_op->get_output_descriptions(v8::If::ELSE_BODY_INDEX));

// create new If inputs
std::vector<std::pair<int, int>> inputs_mapping(if_op->get_input_size(), {-1, -1});
for (auto inp_desc : then_descs) {
inputs_mapping[inp_desc->m_input_index].first = static_cast<int>(inp_desc->m_body_parameter_index);
}
for (auto inp_desc : else_descs) {
inputs_mapping[inp_desc->m_input_index].second = static_cast<int>(inp_desc->m_body_parameter_index);
}
for (size_t j = 0; j < inputs_mapping.size(); j++) {
if (j == i)
continue;
int then_p_idx = inputs_mapping[j].first;
if (then_p_idx > then_body_idx && then_body_idx != -1)
then_p_idx--;
int else_p_idx = inputs_mapping[j].second;
if (else_p_idx > else_body_idx && else_body_idx != -1)
else_p_idx--;
auto then_p = then_p_idx == -1 ? nullptr : then_body->get_parameters()[then_p_idx];
auto else_p = else_p_idx == -1 ? nullptr : else_body->get_parameters()[else_p_idx];
if (then_p || else_p)
new_if->set_invariant_inputs(if_op->input_value(j), {then_p, else_p});
}
for (size_t j = 0; j < tuple_construct->get_input_size(); j++) {
ParameterVector body_inps;
if (then_body_idx != -1) {
FRONT_END_GENERAL_CHECK(j < new_then_params.size(), "Unexpected number of Parameters.");
body_inps.push_back(new_then_params[j]);
} else {
body_inps.push_back(nullptr);
}
if (else_body_idx != -1) {
FRONT_END_GENERAL_CHECK(j < new_else_params.size(), "Unexpected number of Parameters.");
body_inps.push_back(new_else_params[j]);
} else {
body_inps.push_back(nullptr);
}
new_if->set_invariant_inputs(tuple_construct->input_value(j), body_inps);
}
new_if->set_friendly_name(if_op->get_friendly_name());
replace_node(if_op, new_if);
new_if->validate_and_infer_types();
op = std::dynamic_pointer_cast<Node>(new_if);
result = true;
break;
}
}
if (const auto multiSubGraph = ov::as_type_ptr<ov::op::util::MultiSubGraphOp>(op)) {
for (size_t i = 0; i < multiSubGraph->get_internal_subgraphs_size(); i++)
result = result || run_on_model(multiSubGraph->get_function(i));
}
}

return result;
};

} // namespace pass
} // namespace pytorch
} // namespace frontend
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ class PrimTupleUnpackReplacer : public ov::pass::MatcherPass {
PrimTupleUnpackReplacer();
};

class TupleUnpackInBodyReplacer : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::TupleUnpackInBodyReplacer");
bool run_on_model(const std::shared_ptr<Model>& model) override;
};

} // namespace pass
} // namespace pytorch
} // namespace frontend
Expand Down
30 changes: 29 additions & 1 deletion tests/layer_tests/pytorch_tests/test_tuple_construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def create_model(self):

class model(torch.nn.Module):
def forward(self, x):
return self.some_func((x,x))
return self.some_func((x, x))

def some_func(self, x: Tuple[torch.Tensor, torch.Tensor]):
return x[1] * 2, x[0] * 3
Expand All @@ -209,3 +209,31 @@ def some_func(self, x: Tuple[torch.Tensor, torch.Tensor]):
def test(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision,
ir_version, trace_model=False, freeze_model=False)


class TestTcOutsideTuInsideIfBody(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(1, 2, 10).astype(np.float32), np.random.randn(1, 2, 10).astype(np.float32))

def create_model(self):
import torch
from typing import Tuple

class model(torch.nn.Module):
def forward(self, x, y):
return self.some_func((x, y))

def some_func(self, x: Tuple[torch.Tensor, torch.Tensor]):
if x[0].numel() > 10:
n, m = x
return n * m
else:
n, m = x
return n - m

return model(), None, ["prim::TupleConstruct", "prim::TupleUnpack", "prim::If"]

@pytest.mark.nightly
def test(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision,
ir_version, trace_model=False, freeze_model=False)

0 comments on commit f38b5f4

Please sign in to comment.