diff --git a/src/core/src/op/einsum.cpp b/src/core/src/op/einsum.cpp index 41f4aed079fdad..4ec85e5321faa1 100644 --- a/src/core/src/op/einsum.cpp +++ b/src/core/src/op/einsum.cpp @@ -186,7 +186,7 @@ void op::v7::Einsum::validate_and_infer_types() { for (size_t input_idx = 1; input_idx < num_inputs; ++input_idx) { const auto& input_type_i = get_input_element_type(input_idx); NODE_VALIDATION_CHECK(this, - input_type_0 == input_type_i, + input_type_0.compatible(input_type_i), "Inputs to Einsum operation must have the same type."); } diff --git a/src/core/src/op/loop.cpp b/src/core/src/op/loop.cpp index 07744674e07c2c..34ada949fd4988 100644 --- a/src/core/src/op/loop.cpp +++ b/src/core/src/op/loop.cpp @@ -162,6 +162,8 @@ void op::v5::Loop::validate_and_infer_types() { if (auto slice_input_description = ov::as_type_ptr(input_description)) { auto body_parameter = m_bodies[0]->get_parameters().at(slice_input_description->m_body_parameter_index); const auto& input_partial_shape = inputs().at(index).get_source_output().get_partial_shape(); + const auto& input_type = inputs().at(index).get_source_output().get_element_type(); + body_parameter->set_element_type(input_type); if (input_partial_shape.rank().is_dynamic()) { body_parameter->set_partial_shape(ov::PartialShape::dynamic()); } else { @@ -176,19 +178,21 @@ void op::v5::Loop::validate_and_infer_types() { auto body_parameter = m_bodies[0]->get_parameters().at(merged_input_description->m_body_parameter_index); - auto body_param_partial_shape = body_parameter->get_partial_shape(); auto input_partial_shape = input(index).get_partial_shape(); + auto input_type = input(index).get_element_type(); body_parameter->set_partial_shape(input_partial_shape); + body_parameter->set_element_type(input_type); back_edges[merged_input_description->m_body_value_index] = merged_input_description->m_body_parameter_index; } else if (auto invariant_input_description = ov::as_type_ptr(input_description)) { auto body_parameter = m_bodies[0]->get_parameters().at(invariant_input_description->m_body_parameter_index); - auto body_param_partial_shape = body_parameter->get_partial_shape(); auto input_partial_shape = input(index).get_partial_shape(); + auto input_type = input(index).get_element_type(); body_parameter->set_partial_shape(input_partial_shape); + body_parameter->set_element_type(input_type); } } diff --git a/src/frontends/pytorch/src/op/cat.cpp b/src/frontends/pytorch/src/op/cat.cpp new file mode 100644 index 00000000000000..a769de9fdcc48f --- /dev/null +++ b/src/frontends/pytorch/src/op/cat.cpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "pt_framework_node.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_cat(NodeContext& context) { + // This translator is only needed to get axis as constant from external scope + num_inputs_check(context, 2, 2); + auto fw_node = std::make_shared(context.get_decoder(), OutputVector{context.get_input(0)}, 1); + auto attrs = fw_node->get_attrs(); + // If this fails it means axis is dynamic and aten::cat will be converted to fw node in regular pipeline + attrs["axis"] = std::to_string(context.const_input(1)); + fw_node->set_attrs(attrs); + return {context.mark_node(std::dynamic_pointer_cast(fw_node))}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op/list_construct.cpp b/src/frontends/pytorch/src/op/list_construct.cpp index 2e1c3af328d7ff..e69188e23d89a3 100644 --- a/src/frontends/pytorch/src/op/list_construct.cpp +++ b/src/frontends/pytorch/src/op/list_construct.cpp @@ -31,7 +31,7 @@ OutputVector translate_list_construct(NodeContext& context) { consts.push_back(unsqueezed_c_node); } } - auto list_construct = std::make_shared(consts, 0); + auto list_construct = context.mark_node(std::make_shared(consts, 0)); if (list_construct->has_evaluate()) { OutputVector replacements(list_construct->get_output_size()); @@ -39,7 +39,7 @@ OutputVector translate_list_construct(NodeContext& context) { return replacements; } } - return {context.mark_output(list_construct)}; + return {list_construct}; }; } // namespace op diff --git a/src/frontends/pytorch/src/op/loop.cpp b/src/frontends/pytorch/src/op/loop.cpp index 75107b9503b0c1..7bf03cfcd30138 100644 --- a/src/frontends/pytorch/src/op/loop.cpp +++ b/src/frontends/pytorch/src/op/loop.cpp @@ -26,7 +26,12 @@ OutputVector translate_loop(NodeContext& context) { loop->set_special_body_ports(spec_ports); auto body_parameters = body->get_parameters(); - // #0 body parameter is counter; #0 loop input is counter, #1 loop input is condition + // #0 body parameter is counter; + FRONT_END_OP_CONVERSION_CHECK(body_parameters.size() > 0, "At least one input to Loop body is required"); + // Set counter type and shape + body_parameters[0]->set_element_type(element::i32); + body_parameters[0]->set_partial_shape(PartialShape{}); + // #0 loop input is trip_count, #1 loop input is condition // Connect other inputs for (size_t i = 2; i < inputs.size(); i++) { loop->set_invariant_inputs(inputs[i], {body_parameters[i - 1]}); @@ -39,7 +44,6 @@ OutputVector translate_loop(NodeContext& context) { auto external_output = context.get_tensor_from_model_or_create_input(input_idx); loop->set_invariant_inputs(external_output, {param}); } - // TODO: Connect back edges (merged inputs) auto body_results = body->get_results(); FRONT_END_OP_CONVERSION_CHECK(body_results.size() > 0, "At least one output from loop is required - condition."); std::set output_idxs; diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index bd2e9bf0564e7b..3a2c55eab270ea 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -24,6 +24,7 @@ OP_CONVERTER(translate_as_tensor); OP_CONVERTER(translate_avg_poolnd); OP_CONVERTER(translate_bool); OP_CONVERTER(translate_batch_norm); +OP_CONVERTER(translate_cat); OP_CONVERTER(translate_clamp); OP_CONVERTER(translate_constant); OP_CONVERTER(translate_conv_transposend); @@ -160,7 +161,7 @@ const std::map get_supported_ops() { {"aten::batch_norm", op::translate_batch_norm}, {"aten::bmm", op::translate_1to1_match_2_inputs}, {"aten::Bool", op::translate_bool}, - // {"aten::cat", done as transformation}, + {"aten::cat", op::translate_cat}, {"aten::ceil", op::translate_1to1_match_1_inputs}, {"aten::ceil_", op::inplace_op>}, {"aten::clamp", op::translate_clamp}, diff --git a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp index dfa4e6e819d892..dd501c3ea351ba 100644 --- a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp @@ -10,6 +10,7 @@ #include "openvino/core/rt_info.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/loop.hpp" #include "openvino/op/util/framework_node.hpp" #include "openvino/pass/pattern/matcher.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" @@ -37,17 +38,67 @@ AtenCatToConcat::AtenCatToConcat() { if (!cat) return false; - auto axis_node = cat->input(1).get_source_output().get_node_shared_ptr(); - auto axis_const = std::dynamic_pointer_cast(axis_node); - if (!axis_const) - return false; - auto axis = axis_const->cast_vector(); - if (axis.size() != 1) - return false; + int64_t axis; + if (cat->get_input_size() > 1) { + auto axis_node = cat->get_input_node_shared_ptr(1); + auto axis_const = std::dynamic_pointer_cast(axis_node); + if (!axis_const) + return false; + auto _axis = axis_const->cast_vector(); + if (_axis.size() != 1) + return false; + axis = _axis[0]; + } else { + const auto& attrs = cat->get_attrs(); + if (attrs.find("axis") == attrs.end()) + return false; + axis = std::stoll(attrs.at("axis")); + } + + std::shared_ptr input_node = cat->get_input_node_shared_ptr(0); + if (auto loop = std::dynamic_pointer_cast(input_node)) { + // case when concatenation is done inside the Loop + auto body = loop->get_function(); + auto output_index = cat->input(0).get_source_output().get_index(); + int64_t body_result_index = -1; + for (auto out_desc : loop->get_output_descriptions()) { + if (out_desc->m_output_index == output_index) { + body_result_index = static_cast(out_desc->m_body_value_index); + break; + } + } + FRONT_END_GENERAL_CHECK(body_result_index >= 0, "Couldn't find descriptor for output."); + auto body_result = body->get_results()[body_result_index]; + auto append = cast_fw_node(body_result->get_input_node_shared_ptr(0), "aten::append"); + if (!append) + return false; + auto param = std::dynamic_pointer_cast(append->get_input_node_shared_ptr(0)); + if (!param) + return false; + auto body_param_index = body->get_parameter_index(param); + FRONT_END_GENERAL_CHECK(body_param_index >= 0, "Couldn't find parameter in body parameters."); + int64_t input_index = -1; + for (auto in_desc : loop->get_input_descriptions()) { + if (in_desc->m_body_parameter_index == static_cast(body_param_index)) { + input_index = static_cast(in_desc->m_input_index); + break; + } + } + FRONT_END_GENERAL_CHECK(input_index >= 0, "Couldn't find descriptor for input."); + auto list_construct = cast_fw_node(loop->get_input_node_shared_ptr(input_index), "prim::ListConstruct"); + if (!list_construct || list_construct->get_input_size() > 0) + return false; + // TODO: Is unsqueeze needed? + auto new_result = std::make_shared(append->input_value(1)); + body->add_results({new_result}); + auto new_output = loop->get_concatenated_slices(new_result, 0, 1, 1, -1, axis); + copy_runtime_info(cat, loop); + cat->output(0).replace(new_output); + return true; + } OutputVector tmp_inputs; NodeVector rt_copy_from{cat}; - std::shared_ptr input_node = cat->input(0).get_source_output().get_node_shared_ptr(); while (const auto& input_fw_node = cast_fw_node(input_node, "aten::append")) { rt_copy_from.push_back(input_fw_node); tmp_inputs.push_back(input_fw_node->input(1).get_source_output()); @@ -62,7 +113,7 @@ AtenCatToConcat::AtenCatToConcat() { inputs.push_back(input.get_source_output()); } inputs.insert(inputs.end(), tmp_inputs.rbegin(), tmp_inputs.rend()); - auto result = std::make_shared(inputs, axis[0]); + auto result = std::make_shared(inputs, axis); copy_runtime_info(rt_copy_from, result); replace_node(cat, result); diff --git a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py index 6cd0c1cc8fc28b..d5b7f27b70cd5d 100644 --- a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py +++ b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py @@ -52,7 +52,8 @@ def _test(self, model, ref_net, kind, ie_device, precision, ir_version, infer_ti else: torch_inputs = [torch.from_numpy(inp) for inp in inputs] model = torch.jit.trace(model, torch_inputs) - model = torch.jit.freeze(model) + if kwargs.get('freeze_model', True): + model = torch.jit.freeze(model) graph = model.inlined_graph print(graph) diff --git a/tests/layer_tests/pytorch_tests/test_bool.py b/tests/layer_tests/pytorch_tests/test_bool.py index 60b509373bb209..fa13f9d8cc37ac 100644 --- a/tests/layer_tests/pytorch_tests/test_bool.py +++ b/tests/layer_tests/pytorch_tests/test_bool.py @@ -32,5 +32,5 @@ def forward_scalar(self, x:int): @pytest.mark.parametrize("input_type", ["tensor", "scalar"]) @pytest.mark.nightly @pytest.mark.precommit - def test_ceil(self, ie_device, precision, ir_version, input_type): + def test_bool(self, ie_device, precision, ir_version, input_type): self._test(*self.create_model(input_type), ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_cat.py b/tests/layer_tests/pytorch_tests/test_cat.py new file mode 100644 index 00000000000000..772be8d88b40c3 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_cat.py @@ -0,0 +1,51 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +class aten_cat(torch.nn.Module): + def forward(self, x): + return torch.cat([x, x], 1) + + +class aten_append_cat(torch.nn.Module): + def forward(self, x): + list = [] + list.append(x) + list.append(x) + return torch.cat(list, 1) + +class aten_loop_append_cat(torch.nn.Module): + def forward(self, x): + list = [] + for i in range(3): + list.append(x) + return torch.cat(list, 1) + + +class TestCat(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(2, 1, 3),) + + @pytest.mark.nightly + @pytest.mark.precommit + def test_cat(self, ie_device, precision, ir_version): + self._test(aten_cat(), None, ["aten::cat", "prim::ListConstruct"], + ie_device, precision, ir_version) + + @pytest.mark.nightly + @pytest.mark.precommit + def test_append_cat(self, ie_device, precision, ir_version): + self._test(aten_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct"], + ie_device, precision, ir_version) + + @pytest.mark.nightly + @pytest.mark.precommit + def test_loop_append_cat(self, ie_device, precision, ir_version): + self._test(aten_loop_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"], + ie_device, precision, ir_version, freeze_model=False)