diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp index decab9e6df1e15..0910aa3e057e72 100644 --- a/src/frontends/pytorch/src/frontend.cpp +++ b/src/frontends/pytorch/src/frontend.cpp @@ -178,6 +178,7 @@ void FrontEnd::normalize(const std::shared_ptr& model) const { manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/src/frontends/pytorch/src/pt_framework_node.hpp b/src/frontends/pytorch/src/pt_framework_node.hpp index 9dc78c0f0ae59b..04b71d1169ae81 100644 --- a/src/frontends/pytorch/src/pt_framework_node.hpp +++ b/src/frontends/pytorch/src/pt_framework_node.hpp @@ -75,8 +75,8 @@ class PtFrameworkNode : public ov::op::util::FrameworkNode { return m_decoder->get_op_type(); } - TorchDecoder* get_decoder() const { - return m_decoder.get(); + std::shared_ptr get_decoder() const { + return m_decoder; } private: diff --git a/src/frontends/pytorch/src/transforms/tuple_unpack_replacer.cpp b/src/frontends/pytorch/src/transforms/tuple_unpack_replacer.cpp index ae1c1aa3379d84..9352d148e823f7 100644 --- a/src/frontends/pytorch/src/transforms/tuple_unpack_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/tuple_unpack_replacer.cpp @@ -4,9 +4,11 @@ #include "tuple_unpack_replacer.hpp" +#include "openvino/op/if.hpp" #include "openvino/op/util/framework_node.hpp" #include "openvino/pass/pattern/matcher.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" +#include "pt_framework_node.hpp" #include "utils.hpp" namespace ov { @@ -14,6 +16,8 @@ namespace frontend { namespace pytorch { namespace pass { +using namespace ov::op; + PrimTupleUnpackReplacer::PrimTupleUnpackReplacer() { auto tuple_unpack = ov::pass::pattern::wrap_type(); @@ -41,6 +45,144 @@ PrimTupleUnpackReplacer::PrimTupleUnpackReplacer() { this->register_matcher(m, callback); }; +bool TupleUnpackInBodyReplacer::run_on_model(const std::shared_ptr& model) { + bool result = false; + for (auto op : model->get_ordered_ops()) { + const auto if_op = as_type_ptr(op); + if (if_op) { + for (size_t i = 1; i < if_op->get_input_size(); i++) { + auto input = if_op->input_value(i); + auto tuple_construct = std::dynamic_pointer_cast( + cast_fw_node(input.get_node_shared_ptr(), "prim::TupleConstruct")); + if (!tuple_construct) { + continue; + } + int then_body_idx = -1; + int else_body_idx = -1; + auto then_descs = if_op->get_input_descriptions(v8::If::THEN_BODY_INDEX); + auto else_descs = if_op->get_input_descriptions(v8::If::ELSE_BODY_INDEX); + for (auto inp_desc : then_descs) { + if (inp_desc->m_input_index == i) { + if (then_body_idx != -1) { + add_exception_to_fw_node( + tuple_construct, + "Unexpected: TupleConstruct output is used in body more then once."); + } else { + then_body_idx = static_cast(inp_desc->m_body_parameter_index); + } + } + } + for (auto inp_desc : else_descs) { + if (inp_desc->m_input_index == i) { + if (else_body_idx != -1) { + add_exception_to_fw_node( + tuple_construct, + "Unexpected: TupleConstruct output is used in body more then once."); + } else { + else_body_idx = static_cast(inp_desc->m_body_parameter_index); + } + } + } + auto new_if = std::make_shared(if_op->input_value(0)); + auto then_body = if_op->get_function(v8::If::THEN_BODY_INDEX); + auto else_body = if_op->get_function(v8::If::ELSE_BODY_INDEX); + ov::ParameterVector new_then_params; + ov::ParameterVector new_else_params; + if (then_body_idx != -1) { + auto then_param = then_body->get_parameters().at(then_body_idx); + ov::OutputVector new_tc_inputs; + for (size_t i = 0; i < tuple_construct->get_input_size(); i++) { + auto new_param = std::make_shared(element::dynamic, PartialShape::dynamic()); + new_then_params.push_back(new_param); + new_tc_inputs.push_back(new_param); + } + auto new_tc = + std::make_shared(tuple_construct->get_decoder(), + new_tc_inputs, + 1); + then_body->add_parameters(new_then_params); + then_body->remove_parameter(then_param); + then_param->output(0).replace(new_tc->output(0)); + } + if (else_body_idx != -1) { + auto else_param = else_body->get_parameters().at(else_body_idx); + ov::OutputVector new_tc_inputs; + for (size_t i = 0; i < tuple_construct->get_input_size(); i++) { + auto new_param = std::make_shared(element::dynamic, PartialShape::dynamic()); + new_else_params.push_back(new_param); + new_tc_inputs.push_back(new_param); + } + auto new_tc = + std::make_shared(tuple_construct->get_decoder(), + new_tc_inputs, + 1); + else_body->add_parameters(new_else_params); + else_body->remove_parameter(else_param); + else_param->output(0).replace(new_tc->output(0)); + } + new_if->set_function(v8::If::THEN_BODY_INDEX, then_body); + new_if->set_function(v8::If::ELSE_BODY_INDEX, else_body); + new_if->set_output_size(if_op->get_output_size()); + new_if->set_output_descriptions(v8::If::THEN_BODY_INDEX, + if_op->get_output_descriptions(v8::If::THEN_BODY_INDEX)); + new_if->set_output_descriptions(v8::If::ELSE_BODY_INDEX, + if_op->get_output_descriptions(v8::If::ELSE_BODY_INDEX)); + + // create new If inputs + std::vector> inputs_mapping(if_op->get_input_size(), {-1, -1}); + for (auto inp_desc : then_descs) { + inputs_mapping[inp_desc->m_input_index].first = static_cast(inp_desc->m_body_parameter_index); + } + for (auto inp_desc : else_descs) { + inputs_mapping[inp_desc->m_input_index].second = static_cast(inp_desc->m_body_parameter_index); + } + for (size_t j = 0; j < inputs_mapping.size(); j++) { + if (j == i) + continue; + int then_p_idx = inputs_mapping[j].first; + if (then_p_idx > then_body_idx && then_body_idx != -1) + then_p_idx--; + int else_p_idx = inputs_mapping[j].second; + if (else_p_idx > else_body_idx && else_body_idx != -1) + else_p_idx--; + auto then_p = then_p_idx == -1 ? nullptr : then_body->get_parameters()[then_p_idx]; + auto else_p = else_p_idx == -1 ? nullptr : else_body->get_parameters()[else_p_idx]; + if (then_p || else_p) + new_if->set_invariant_inputs(if_op->input_value(j), {then_p, else_p}); + } + for (size_t j = 0; j < tuple_construct->get_input_size(); j++) { + ParameterVector body_inps; + if (then_body_idx != -1) { + FRONT_END_GENERAL_CHECK(j < new_then_params.size(), "Unexpected number of Parameters."); + body_inps.push_back(new_then_params[j]); + } else { + body_inps.push_back(nullptr); + } + if (else_body_idx != -1) { + FRONT_END_GENERAL_CHECK(j < new_else_params.size(), "Unexpected number of Parameters."); + body_inps.push_back(new_else_params[j]); + } else { + body_inps.push_back(nullptr); + } + new_if->set_invariant_inputs(tuple_construct->input_value(j), body_inps); + } + new_if->set_friendly_name(if_op->get_friendly_name()); + replace_node(if_op, new_if); + new_if->validate_and_infer_types(); + op = std::dynamic_pointer_cast(new_if); + result = true; + break; + } + } + if (const auto multiSubGraph = ov::as_type_ptr(op)) { + for (size_t i = 0; i < multiSubGraph->get_internal_subgraphs_size(); i++) + result = result || run_on_model(multiSubGraph->get_function(i)); + } + } + + return result; +}; + } // namespace pass } // namespace pytorch } // namespace frontend diff --git a/src/frontends/pytorch/src/transforms/tuple_unpack_replacer.hpp b/src/frontends/pytorch/src/transforms/tuple_unpack_replacer.hpp index 012c8e8c05ff0e..81aae1eefaf4d6 100644 --- a/src/frontends/pytorch/src/transforms/tuple_unpack_replacer.hpp +++ b/src/frontends/pytorch/src/transforms/tuple_unpack_replacer.hpp @@ -18,6 +18,12 @@ class PrimTupleUnpackReplacer : public ov::pass::MatcherPass { PrimTupleUnpackReplacer(); }; +class TupleUnpackInBodyReplacer : public ov::pass::ModelPass { +public: + OPENVINO_RTTI("ov::frontend::pytorch::pass::TupleUnpackInBodyReplacer"); + bool run_on_model(const std::shared_ptr& model) override; +}; + } // namespace pass } // namespace pytorch } // namespace frontend diff --git a/tests/layer_tests/pytorch_tests/test_tuple_construct.py b/tests/layer_tests/pytorch_tests/test_tuple_construct.py index b4f48354dcfdb6..1582df48c4b370 100644 --- a/tests/layer_tests/pytorch_tests/test_tuple_construct.py +++ b/tests/layer_tests/pytorch_tests/test_tuple_construct.py @@ -198,7 +198,7 @@ def create_model(self): class model(torch.nn.Module): def forward(self, x): - return self.some_func((x,x)) + return self.some_func((x, x)) def some_func(self, x: Tuple[torch.Tensor, torch.Tensor]): return x[1] * 2, x[0] * 3 @@ -209,3 +209,31 @@ def some_func(self, x: Tuple[torch.Tensor, torch.Tensor]): def test(self, ie_device, precision, ir_version): self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=False, freeze_model=False) + + +class TestTcOutsideTuInsideIfBody(PytorchLayerTest): + def _prepare_input(self): + return (np.random.randn(1, 2, 10).astype(np.float32), np.random.randn(1, 2, 10).astype(np.float32)) + + def create_model(self): + import torch + from typing import Tuple + + class model(torch.nn.Module): + def forward(self, x, y): + return self.some_func((x, y)) + + def some_func(self, x: Tuple[torch.Tensor, torch.Tensor]): + if x[0].numel() > 10: + n, m = x + return n * m + else: + n, m = x + return n - m + + return model(), None, ["prim::TupleConstruct", "prim::TupleUnpack", "prim::If"] + + @pytest.mark.nightly + def test(self, ie_device, precision, ir_version): + self._test(*self.create_model(), ie_device, precision, + ir_version, trace_model=False, freeze_model=False)