From 3f1013f5c544d1f95ba7cd6f5a6cbcd139afc4e3 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Wed, 22 Feb 2023 23:12:49 +0100
Subject: [PATCH 1/6] Add support for concatenation in Loop

---
 src/frontends/pytorch/src/op/cat.cpp          | 28 ++++++++
 .../pytorch/src/op/list_construct.cpp         |  4 +-
 src/frontends/pytorch/src/op_table.cpp        |  5 +-
 .../src/transforms/aten_cat_replacer.cpp      | 69 ++++++++++++++++---
 .../pytorch_tests/pytorch_layer_test_class.py |  3 +-
 tests/layer_tests/pytorch_tests/test_bool.py  |  2 +-
 tests/layer_tests/pytorch_tests/test_cat.py   | 51 ++++++++++++++
 7 files changed, 148 insertions(+), 14 deletions(-)
 create mode 100644 src/frontends/pytorch/src/op/cat.cpp
 create mode 100644 tests/layer_tests/pytorch_tests/test_cat.py

diff --git a/src/frontends/pytorch/src/op/cat.cpp b/src/frontends/pytorch/src/op/cat.cpp
new file mode 100644
index 00000000000000..a769de9fdcc48f
--- /dev/null
+++ b/src/frontends/pytorch/src/op/cat.cpp
@@ -0,0 +1,28 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "pt_framework_node.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+OutputVector translate_cat(NodeContext& context) {
+    // This translator is only needed to get axis as constant from external scope
+    num_inputs_check(context, 2, 2);
+    auto fw_node = std::make_shared<PtFrameworkNode>(context.get_decoder(), OutputVector{context.get_input(0)}, 1);
+    auto attrs = fw_node->get_attrs();
+    // If this fails it means axis is dynamic and aten::cat will be converted to fw node in regular pipeline
+    attrs["axis"] = std::to_string(context.const_input<int64_t>(1));
+    fw_node->set_attrs(attrs);
+    return {context.mark_node(std::dynamic_pointer_cast<Node>(fw_node))};
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
\ No newline at end of file
diff --git a/src/frontends/pytorch/src/op/list_construct.cpp b/src/frontends/pytorch/src/op/list_construct.cpp
index 2e1c3af328d7ff..e69188e23d89a3 100644
--- a/src/frontends/pytorch/src/op/list_construct.cpp
+++ b/src/frontends/pytorch/src/op/list_construct.cpp
@@ -31,7 +31,7 @@ OutputVector translate_list_construct(NodeContext& context) {
             consts.push_back(unsqueezed_c_node);
         }
     }
-    auto list_construct = std::make_shared<opset10::Concat>(consts, 0);
+    auto list_construct = context.mark_node(std::make_shared<opset10::Concat>(consts, 0));
     if (list_construct->has_evaluate()) {
         OutputVector replacements(list_construct->get_output_size());
@@ -39,7 +39,7 @@
             return replacements;
         }
     }
-    return {context.mark_output(list_construct)};
+    return {list_construct};
 };
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 098afbfc9a6a8c..313af17156d869 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -24,6 +24,7 @@ OP_CONVERTER(translate_as_tensor);
 OP_CONVERTER(translate_avg_poolnd);
 OP_CONVERTER(translate_bool);
 OP_CONVERTER(translate_batch_norm);
+OP_CONVERTER(translate_cat);
 OP_CONVERTER(translate_clamp);
 OP_CONVERTER(translate_constant);
 OP_CONVERTER(translate_conv_transposend);
@@ -93,6 +94,7 @@ OP_CONVERTER(translate_roll);
 OP_CONVERTER(translate_rsqrt);
 OP_CONVERTER(translate_rsub);
 OP_CONVERTER(translate_select);
+OP_CONVERTER(translate_set_attr);
 OP_CONVERTER(translate_set_item);
 OP_CONVERTER(translate_selu);
 OP_CONVERTER(translate_size);
@@ -155,7 +157,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::batch_norm", op::translate_batch_norm},
         {"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
         {"aten::Bool", op::translate_bool},
-        // {"aten::cat", done as transformation},
+        {"aten::cat", op::translate_cat},
         {"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
         {"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
         {"aten::clamp", op::translate_clamp},
@@ -318,6 +320,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"prim::Loop", op::translate_loop},
         {"prim::NumToTensor", op::skip_node},  // In openvino we already store number as tensor with shape []
         {"prim::requires_grad", op::return_false_scalar},
+        {"prim::SetAttr", op::translate_set_attr},
         {"torchvision::nms", op::translate_nms},
     };
 };
diff --git a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
index dfa4e6e819d892..f8bd7e8933e647 100644
--- a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
@@ -10,6 +10,7 @@
 #include "openvino/core/rt_info.hpp"
 #include "openvino/op/concat.hpp"
 #include "openvino/op/constant.hpp"
+#include "openvino/op/loop.hpp"
 #include "openvino/op/util/framework_node.hpp"
 #include "openvino/pass/pattern/matcher.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
@@ -37,17 +38,67 @@ AtenCatToConcat::AtenCatToConcat() {
         if (!cat)
             return false;
 
-        auto axis_node = cat->input(1).get_source_output().get_node_shared_ptr();
-        auto axis_const = std::dynamic_pointer_cast<v0::Constant>(axis_node);
-        if (!axis_const)
-            return false;
-        auto axis = axis_const->cast_vector<int64_t>();
-        if (axis.size() != 1)
-            return false;
+        int64_t axis;
+        if (cat->get_input_size() > 1) {
+            auto axis_node = cat->get_input_node_shared_ptr(1);
+            auto axis_const = std::dynamic_pointer_cast<v0::Constant>(axis_node);
+            if (!axis_const)
+                return false;
+            auto _axis = axis_const->cast_vector<int64_t>();
+            if (_axis.size() != 1)
+                return false;
+            axis = _axis[0];
+        } else {
+            const auto& attrs = cat->get_attrs();
+            if (attrs.find("axis") == attrs.end())
+                return false;
+            axis = std::stoll(attrs.at("axis"));
+        }
+
+        std::shared_ptr<Node> input_node = cat->get_input_node_shared_ptr(0);
+        if (auto loop = std::dynamic_pointer_cast<v5::Loop>(input_node)) {
+            // case when concatenation is done inside the Loop
+            auto body = loop->get_function();
+            auto output_index = cat->input(0).get_source_output().get_index();
+            int body_result_index = -1;
+            for (auto out_desc : loop->get_output_descriptions()) {
+                if (out_desc->m_output_index == output_index) {
+                    body_result_index = out_desc->m_body_value_index;
+                    break;
+                }
+            }
+            FRONT_END_GENERAL_CHECK(body_result_index >= 0, "Couldn't find descriptor for output.");
+            auto body_result = body->get_results()[body_result_index];
+            auto append = cast_fw_node(body_result->get_input_node_shared_ptr(0), "aten::append");
+            if (!append)
+                return false;
+            auto param = std::dynamic_pointer_cast<v0::Parameter>(append->get_input_node_shared_ptr(0));
+            if (!param)
+                return false;
+            auto body_param_index = body->get_parameter_index(param);
+            FRONT_END_GENERAL_CHECK(body_param_index >= 0, "Couldn't find parameter in body parameters.");
+            int input_index = -1;
+            for (auto in_desc : loop->get_input_descriptions()) {
+                if (in_desc->m_body_parameter_index == static_cast<size_t>(body_param_index)) {
+                    input_index = in_desc->m_input_index;
+                    break;
+                }
+            }
+            FRONT_END_GENERAL_CHECK(input_index >= 0, "Couldn't find descriptor for input.");
+            auto list_construct = cast_fw_node(loop->get_input_node_shared_ptr(input_index), "prim::ListConstruct");
+            if (!list_construct || list_construct->get_input_size() > 0)
+                return false;
+            // TODO: Is unsqueeze needed?
+            auto new_result = std::make_shared<v0::Result>(append->input_value(1));
+            body->add_results({new_result});
+            auto new_output = loop->get_concatenated_slices(new_result, 0, 1, 1, -1, axis);
+            copy_runtime_info(cat, loop);
+            cat->output(0).replace(new_output);
+            return true;
+        }
 
         OutputVector tmp_inputs;
         NodeVector rt_copy_from{cat};
-        std::shared_ptr<Node> input_node = cat->input(0).get_source_output().get_node_shared_ptr();
         while (const auto& input_fw_node = cast_fw_node(input_node, "aten::append")) {
             rt_copy_from.push_back(input_fw_node);
             tmp_inputs.push_back(input_fw_node->input(1).get_source_output());
@@ -62,7 +113,7 @@ AtenCatToConcat::AtenCatToConcat() {
             inputs.push_back(input.get_source_output());
         }
         inputs.insert(inputs.end(), tmp_inputs.rbegin(), tmp_inputs.rend());
-        auto result = std::make_shared<v0::Concat>(inputs, axis[0]);
+        auto result = std::make_shared<v0::Concat>(inputs, axis);
         copy_runtime_info(rt_copy_from, result);
         replace_node(cat, result);
diff --git a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
index 6cd0c1cc8fc28b..d5b7f27b70cd5d 100644
--- a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
+++ b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
@@ -52,7 +52,8 @@ def _test(self, model, ref_net, kind, ie_device, precision, ir_version, infer_ti
         else:
             torch_inputs = [torch.from_numpy(inp) for inp in inputs]
             model = torch.jit.trace(model, torch_inputs)
-        model = torch.jit.freeze(model)
+        if kwargs.get('freeze_model', True):
+            model = torch.jit.freeze(model)
         graph = model.inlined_graph
         print(graph)
diff --git a/tests/layer_tests/pytorch_tests/test_bool.py b/tests/layer_tests/pytorch_tests/test_bool.py
index 60b509373bb209..fa13f9d8cc37ac 100644
--- a/tests/layer_tests/pytorch_tests/test_bool.py
+++ b/tests/layer_tests/pytorch_tests/test_bool.py
@@ -32,5 +32,5 @@ def forward_scalar(self, x:int):
     @pytest.mark.parametrize("input_type", ["tensor", "scalar"])
     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_ceil(self, ie_device, precision, ir_version, input_type):
+    def test_bool(self, ie_device, precision, ir_version, input_type):
         self._test(*self.create_model(input_type), ie_device, precision, ir_version)
diff --git a/tests/layer_tests/pytorch_tests/test_cat.py b/tests/layer_tests/pytorch_tests/test_cat.py
new file mode 100644
index 00000000000000..476530b7f4aca2
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_cat.py
@@ -0,0 +1,51 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class aten_cat(torch.nn.Module):
+    def forward(self, x):
+        return torch.cat([x, x], 1)
+
+
+class aten_append_cat(torch.nn.Module):
+    def forward(self, x):
+        list = []
+        list.append(x)
+        list.append(x)
+        return torch.cat(list, 1)
+
+class aten_loop_append_cat(torch.nn.Module):
+    def forward(self, x):
+        list = []
+        for i in range(3):
+            list.append(x)
+        return torch.cat(list, 1)
+
+
+class TestCat(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(1, 2, 3),)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_cat(self, ie_device, precision, ir_version):
+        self._test(aten_cat(), None, ["aten::cat", "prim::ListConstruct"],
+                   ie_device, precision, ir_version)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_append_cat(self, ie_device, precision, ir_version):
+        self._test(aten_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct"],
+                   ie_device, precision, ir_version)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_loop_append_cat(self, ie_device, precision, ir_version):
+        self._test(aten_loop_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"],
+                   ie_device, precision, ir_version, freeze_model=False)
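
For context on the pattern patch 1 targets: an `aten::cat` over a list filled by `aten::append` inside a `prim::Loop` only appears when the model is scripted, since tracing unrolls Python loops; the new `freeze_model` knob exists because the loop test has to skip `torch.jit.freeze`, which can otherwise optimize a constant-trip-count loop away. A minimal sketch of that TorchScript pattern (the module name is illustrative, not part of the patch):

```python
import torch

class LoopAppendCat(torch.nn.Module):  # illustrative name, not from the patch
    def forward(self, x):
        xs = []
        for _ in range(3):
            xs.append(x)
        return torch.cat(xs, 1)

# torch.jit.script keeps the loop as a prim::Loop node with aten::append in
# its body and aten::cat consuming the list outside; torch.jit.trace would
# unroll the loop and no prim::Loop would reach the frontend.
scripted = torch.jit.script(LoopAppendCat())
print(scripted.graph)
```
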
From 48f137d03d1b9e78438b7b0051d0ae5c2ab07c76 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Thu, 23 Feb 2023 09:39:31 +0100
Subject: [PATCH 2/6] Apply suggestions from code review

---
 src/frontends/pytorch/src/op_table.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 313af17156d869..459a4aeb80ff9e 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -94,7 +94,6 @@ OP_CONVERTER(translate_roll);
 OP_CONVERTER(translate_rsqrt);
 OP_CONVERTER(translate_rsub);
 OP_CONVERTER(translate_select);
-OP_CONVERTER(translate_set_attr);
 OP_CONVERTER(translate_set_item);
 OP_CONVERTER(translate_selu);
 OP_CONVERTER(translate_size);
@@ -320,7 +319,6 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"prim::Loop", op::translate_loop},
         {"prim::NumToTensor", op::skip_node},  // In openvino we already store number as tensor with shape []
         {"prim::requires_grad", op::return_false_scalar},
-        {"prim::SetAttr", op::translate_set_attr},
         {"torchvision::nms", op::translate_nms},
     };
 };

From 25381a9828c4a8a0d229a693a96c0635ccca2012 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Thu, 23 Feb 2023 17:05:47 +0100
Subject: [PATCH 3/6] Fix win build

---
 .../pytorch/src/transforms/aten_cat_replacer.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
index f8bd7e8933e647..dd501c3ea351ba 100644
--- a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
@@ -60,10 +60,10 @@ AtenCatToConcat::AtenCatToConcat() {
             // case when concatenation is done inside the Loop
             auto body = loop->get_function();
             auto output_index = cat->input(0).get_source_output().get_index();
-            int body_result_index = -1;
+            int64_t body_result_index = -1;
             for (auto out_desc : loop->get_output_descriptions()) {
                 if (out_desc->m_output_index == output_index) {
-                    body_result_index = out_desc->m_body_value_index;
+                    body_result_index = static_cast<int64_t>(out_desc->m_body_value_index);
                     break;
                 }
             }
@@ -77,10 +77,10 @@ AtenCatToConcat::AtenCatToConcat() {
                 return false;
             auto body_param_index = body->get_parameter_index(param);
             FRONT_END_GENERAL_CHECK(body_param_index >= 0, "Couldn't find parameter in body parameters.");
-            int input_index = -1;
+            int64_t input_index = -1;
             for (auto in_desc : loop->get_input_descriptions()) {
-                if (in_desc->m_body_parameter_index == static_cast<size_t>(body_param_index)) {
-                    input_index = in_desc->m_input_index;
+                if (in_desc->m_body_parameter_index == static_cast<uint64_t>(body_param_index)) {
+                    input_index = static_cast<int64_t>(in_desc->m_input_index);
                     break;
                 }
             }
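
Before the next patch: `get_concatenated_slices` stitches the per-iteration results together along the concatenation axis, so the expected output shape can be checked by hand. The `(2, 1, 3)` input adopted in the test change below makes an axis mix-up immediately visible, since every dimension of the result then differs. A quick numpy sanity check (illustrative only, not part of the series):

```python
import numpy as np

# Three loop iterations each append an identical (2, 1, 3) tensor; concatenating
# the per-iteration slices along axis 1 yields shape (2, 3, 3).
x = np.random.randn(2, 1, 3)
out = np.concatenate([x, x, x], axis=1)
assert out.shape == (2, 3, 3)
```
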
From 4fbf7e8b1f4e9a75006611ee925ba69fb4333610 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Mon, 27 Feb 2023 21:40:09 +0100
Subject: [PATCH 4/6] Fix issues with propagation shapes and types in Loop

---
 src/core/src/op/loop.cpp                    | 18 +++++++++++-------
 tests/layer_tests/pytorch_tests/test_cat.py |  2 +-
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/src/core/src/op/loop.cpp b/src/core/src/op/loop.cpp
index 07744674e07c2c..7ed8fc8dcb3b44 100644
--- a/src/core/src/op/loop.cpp
+++ b/src/core/src/op/loop.cpp
@@ -42,11 +42,11 @@ void op::v5::Loop::validate_and_infer_types() {
                           "Loop contains output descriptions for other bodies");
 
     if (m_special_body_ports.current_iteration_input_idx >= 0) {
-        const auto& cur_iter_rank = m_bodies[0]
-                                        ->get_parameters()
-                                        .at(m_special_body_ports.current_iteration_input_idx)
-                                        ->get_partial_shape()
-                                        .rank();
+        // Propagate current_iteration input shape and type
+        auto& iter_param = m_bodies[0]->get_parameters().at(m_special_body_ports.current_iteration_input_idx);
+        iter_param->set_element_type(get_input_element_type(0));
+        iter_param->set_partial_shape(get_input_partial_shape(0));
+        const auto& cur_iter_rank = iter_param->get_partial_shape().rank();
         if (cur_iter_rank.is_static()) {
             NODE_VALIDATION_CHECK(this,
                                   cur_iter_rank.compatible(1) || cur_iter_rank.compatible(0),
@@ -162,6 +162,8 @@ void op::v5::Loop::validate_and_infer_types() {
         if (auto slice_input_description = ov::as_type_ptr<SliceInputDescription>(input_description)) {
             auto body_parameter = m_bodies[0]->get_parameters().at(slice_input_description->m_body_parameter_index);
             const auto& input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
+            const auto& input_type = inputs().at(index).get_source_output().get_element_type();
+            body_parameter->set_element_type(input_type);
             if (input_partial_shape.rank().is_dynamic()) {
                 body_parameter->set_partial_shape(ov::PartialShape::dynamic());
             } else {
@@ -176,19 +178,21 @@ void op::v5::Loop::validate_and_infer_types() {
             auto body_parameter = m_bodies[0]->get_parameters().at(merged_input_description->m_body_parameter_index);
 
-            auto body_param_partial_shape = body_parameter->get_partial_shape();
             auto input_partial_shape = input(index).get_partial_shape();
+            auto input_type = input(index).get_element_type();
             body_parameter->set_partial_shape(input_partial_shape);
+            body_parameter->set_element_type(input_type);
             back_edges[merged_input_description->m_body_value_index] = merged_input_description->m_body_parameter_index;
         } else if (auto invariant_input_description = ov::as_type_ptr<InvariantInputDescription>(input_description)) {
             auto body_parameter = m_bodies[0]->get_parameters().at(invariant_input_description->m_body_parameter_index);
 
-            auto body_param_partial_shape = body_parameter->get_partial_shape();
             auto input_partial_shape = input(index).get_partial_shape();
+            auto input_type = input(index).get_element_type();
             body_parameter->set_partial_shape(input_partial_shape);
+            body_parameter->set_element_type(input_type);
         }
     }
diff --git a/tests/layer_tests/pytorch_tests/test_cat.py b/tests/layer_tests/pytorch_tests/test_cat.py
index 476530b7f4aca2..772be8d88b40c3 100644
--- a/tests/layer_tests/pytorch_tests/test_cat.py
+++ b/tests/layer_tests/pytorch_tests/test_cat.py
@@ -30,7 +30,7 @@ def forward(self, x):
 class TestCat(PytorchLayerTest):
     def _prepare_input(self):
         import numpy as np
-        return (np.random.randn(1, 2, 3),)
+        return (np.random.randn(2, 1, 3),)
 
     @pytest.mark.nightly
     @pytest.mark.precommit
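
The Einsum change in the next patch relaxes the input type check from strict equality to `compatible()`: once Loop bodies carry partially inferred types, one Einsum input may still have a dynamic element type, which is not equal to a concrete type but should not fail validation. A minimal sketch of the intended semantics (not the OpenVINO implementation):

```python
# Sketch of element-type compatibility: a dynamic (unknown) type is compatible
# with any concrete type, while strict equality rejects the dynamic/static pair.
DYNAMIC = "dynamic"

def compatible(a, b):
    return a == DYNAMIC or b == DYNAMIC or a == b

assert compatible("f32", DYNAMIC)   # passes the relaxed check
assert not ("f32" == DYNAMIC)       # strict equality would reject it
```
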
From 4c3ae1373f28e1e4559938fdf4d7cbf73dcbb7f0 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Mon, 27 Feb 2023 23:49:42 +0100
Subject: [PATCH 5/6] Fix einsum

---
 src/core/src/op/einsum.cpp            | 2 +-
 src/frontends/pytorch/src/op/loop.cpp | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/core/src/op/einsum.cpp b/src/core/src/op/einsum.cpp
index 41f4aed079fdad..4ec85e5321faa1 100644
--- a/src/core/src/op/einsum.cpp
+++ b/src/core/src/op/einsum.cpp
@@ -186,7 +186,7 @@ void op::v7::Einsum::validate_and_infer_types() {
     for (size_t input_idx = 1; input_idx < num_inputs; ++input_idx) {
         const auto& input_type_i = get_input_element_type(input_idx);
         NODE_VALIDATION_CHECK(this,
-                              input_type_0 == input_type_i,
+                              input_type_0.compatible(input_type_i),
                               "Inputs to Einsum operation must have the same type.");
     }
 
diff --git a/src/frontends/pytorch/src/op/loop.cpp b/src/frontends/pytorch/src/op/loop.cpp
index 75107b9503b0c1..8039064acf028a 100644
--- a/src/frontends/pytorch/src/op/loop.cpp
+++ b/src/frontends/pytorch/src/op/loop.cpp
@@ -39,7 +39,6 @@ OutputVector translate_loop(NodeContext& context) {
         auto external_output = context.get_tensor_from_model_or_create_input(input_idx);
         loop->set_invariant_inputs(external_output, {param});
     }
-    // TODO: Connect back edges (merged inputs)
    auto body_results = body->get_results();
     FRONT_END_OP_CONVERSION_CHECK(body_results.size() > 0, "At least one output from loop is required - condition.");
     std::set<size_t> output_idxs;

From 2b469bf03fa25175881f99a3f7bdf6c2ef67d3d3 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Tue, 28 Feb 2023 10:56:49 +0100
Subject: [PATCH 6/6] Set type and shape of count in frontend

---
 src/core/src/op/loop.cpp              | 10 +++++-----
 src/frontends/pytorch/src/op/loop.cpp |  7 ++++++-
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/core/src/op/loop.cpp b/src/core/src/op/loop.cpp
index 7ed8fc8dcb3b44..34ada949fd4988 100644
--- a/src/core/src/op/loop.cpp
+++ b/src/core/src/op/loop.cpp
@@ -42,11 +42,11 @@ void op::v5::Loop::validate_and_infer_types() {
                           "Loop contains output descriptions for other bodies");
 
     if (m_special_body_ports.current_iteration_input_idx >= 0) {
-        // Propagate current_iteration input shape and type
-        auto& iter_param = m_bodies[0]->get_parameters().at(m_special_body_ports.current_iteration_input_idx);
-        iter_param->set_element_type(get_input_element_type(0));
-        iter_param->set_partial_shape(get_input_partial_shape(0));
-        const auto& cur_iter_rank = iter_param->get_partial_shape().rank();
+        const auto& cur_iter_rank = m_bodies[0]
+                                        ->get_parameters()
+                                        .at(m_special_body_ports.current_iteration_input_idx)
+                                        ->get_partial_shape()
+                                        .rank();
         if (cur_iter_rank.is_static()) {
             NODE_VALIDATION_CHECK(this,
                                   cur_iter_rank.compatible(1) || cur_iter_rank.compatible(0),
diff --git a/src/frontends/pytorch/src/op/loop.cpp b/src/frontends/pytorch/src/op/loop.cpp
index 8039064acf028a..7bf03cfcd30138 100644
--- a/src/frontends/pytorch/src/op/loop.cpp
+++ b/src/frontends/pytorch/src/op/loop.cpp
@@ -26,7 +26,12 @@ OutputVector translate_loop(NodeContext& context) {
     loop->set_special_body_ports(spec_ports);
 
     auto body_parameters = body->get_parameters();
-    // #0 body parameter is counter; #0 loop input is counter, #1 loop input is condition
+    // #0 body parameter is counter;
+    FRONT_END_OP_CONVERSION_CHECK(body_parameters.size() > 0, "At least one input to Loop body is required");
+    // Set counter type and shape
+    body_parameters[0]->set_element_type(element::i32);
+    body_parameters[0]->set_partial_shape(PartialShape{});
+    // #0 loop input is trip_count, #1 loop input is condition
     // Connect other inputs
     for (size_t i = 2; i < inputs.size(); i++) {
         loop->set_invariant_inputs(inputs[i], {body_parameters[i - 1]});
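
An end-to-end way to exercise the whole series is to convert a scripted loop-append-cat model and compare against numpy. This is a hedged sketch: it assumes the `openvino.tools.mo.convert_model` entry point for PyTorch that was current around this time, and the exact API names and availability may differ by release.

```python
import numpy as np
import torch
from openvino.runtime import Core
from openvino.tools.mo import convert_model  # assumed entry point

class LoopAppendCat(torch.nn.Module):  # illustrative model, as in the earlier sketch
    def forward(self, x):
        xs = []
        for _ in range(3):
            xs.append(x)
        return torch.cat(xs, 1)

scripted = torch.jit.script(LoopAppendCat())
ov_model = convert_model(scripted, example_input=torch.zeros(2, 1, 3))
compiled = Core().compile_model(ov_model, "CPU")

x = np.random.randn(2, 1, 3).astype(np.float32)
result = compiled([x])[0]  # index 0 selects the single model output
np.testing.assert_allclose(result, np.concatenate([x, x, x], axis=1), rtol=1e-5)
```
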