From 400e75f0e6523eb42a18c2a8e69e8b965cb5cc6d Mon Sep 17 00:00:00 2001 From: Lyamin-Roman Date: Thu, 14 Mar 2024 11:20:55 +0900 Subject: [PATCH] [Transformations] Added If operation to NMS path propagation for ignore negative indices in Gather --- .../convert_nms_gather_path_to_unsigned.cpp | 52 +++++++++++++-- ...nvert_nms_gather_path_to_unsigned_test.cpp | 65 +++++++++++++++++++ 2 files changed, 111 insertions(+), 6 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/convert_nms_gather_path_to_unsigned.cpp b/src/common/transformations/src/transformations/common_optimizations/convert_nms_gather_path_to_unsigned.cpp index 1cd38b4caa0b37..55c563ba166dc8 100644 --- a/src/common/transformations/src/transformations/common_optimizations/convert_nms_gather_path_to_unsigned.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/convert_nms_gather_path_to_unsigned.cpp @@ -10,6 +10,7 @@ #include "openvino/core/rt_info.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/convert.hpp" +#include "openvino/op/if.hpp" #include "openvino/op/non_max_suppression.hpp" #include "openvino/op/reshape.hpp" #include "openvino/op/slice.hpp" @@ -19,8 +20,10 @@ #include "openvino/op/util/broadcast_base.hpp" #include "openvino/op/util/gather_base.hpp" #include "openvino/op/variadic_split.hpp" +#include "openvino/pass/manager.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "transformations/rt_info/nms_selected_indices.hpp" +#include "openvino/op/util/multi_subgraph_base.hpp" using namespace std; @@ -60,14 +63,51 @@ class PropagateNMSPath : public pass::MatcherPass { ov::op::v1::VariadicSplit, op::util::GatherBase, ov::op::v0::Concat, - ov::op::v0::Convert>(); + ov::op::v0::Convert, + ov::op::v8::If>(); matcher_pass_callback callback = [=](pattern::Matcher& m) { + auto propagate_path = [](const ov::OutputVector& input_nodes, ov::Node* target_node) { + if (any_of(input_nodes.begin(), input_nodes.end(), [](const Output& output) { + return ov::has_nms_selected_indices(output.get_node()); + })) { + ov::set_nms_selected_indices(target_node); + } + }; + auto handle_params = + [&propagate_path](std::shared_ptr if_node, std::shared_ptr body, int body_index) { + const auto& params = body->get_parameters(); + for (auto input_desc : if_node->get_input_descriptions(body_index)) { + auto param = params[input_desc->m_body_parameter_index]; + auto input_node = if_node->input(input_desc->m_input_index).get_source_output(); + propagate_path({input_node}, param.get()); + } + }; + auto handle_results = + [&propagate_path](std::shared_ptr if_node, std::shared_ptr body, int body_index) { + const auto& results = body->get_results(); + for (auto output_desc : if_node->get_output_descriptions(body_index)) { + auto result = results[output_desc->m_body_value_index]; + const auto& result_inputs = result->input_values(); + auto output_node = if_node->output(output_desc->m_output_index).get_node(); + propagate_path(result_inputs, output_node); + } + }; + auto node = m.get_match_root(); - const auto& inputs = node->input_values(); - if (any_of(inputs.begin(), inputs.end(), [](const Output& output) { - return ov::has_nms_selected_indices(output.get_node()); - })) { - ov::set_nms_selected_indices(node.get()); + if (ov::is_type(node)) { + auto multi_subgraph_op = ov::as_type_ptr(node); + const auto& models = multi_subgraph_op->get_functions(); + + for (size_t body_idx = 0; body_idx < models.size(); ++body_idx) { + handle_params(multi_subgraph_op, models[body_idx], static_cast(body_idx)); + ov::pass::Manager manager; + manager.register_pass(); + manager.run_passes(models[body_idx]); + handle_results(multi_subgraph_op, models[body_idx], static_cast(body_idx)); + } + } else { + const auto& inputs = node->input_values(); + propagate_path(inputs, node.get()); } return false; }; diff --git a/src/common/transformations/tests/common_optimizations/convert_nms_gather_path_to_unsigned_test.cpp b/src/common/transformations/tests/common_optimizations/convert_nms_gather_path_to_unsigned_test.cpp index e9763e4d6bfec1..3076e32646eaa1 100644 --- a/src/common/transformations/tests/common_optimizations/convert_nms_gather_path_to_unsigned_test.cpp +++ b/src/common/transformations/tests/common_optimizations/convert_nms_gather_path_to_unsigned_test.cpp @@ -205,3 +205,68 @@ TEST(TransformationTests, test_convert_to_unsigned_nms_gather_3) { ASSERT_NO_THROW(check_rt_info(f)); ASSERT_EQ(count_ops_of_type(f), 0); } + +TEST(TransformationTests, test_convert_to_unsigned_nms_gather_with_if_condition) { + auto boxes = make_shared(element::f32, PartialShape{1, -1, 4}); + auto scores = make_shared(element::f32, PartialShape{1, 1, -1}); + auto nms = make_shared(boxes, scores); + + auto gather = make_shared(nms->output(0), + opset8::Constant::create(element::i32, Shape{1}, {2}), + opset8::Constant::create(element::i32, Shape{1}, {0})); + + auto shape_of = make_shared(gather); + auto gather_shape = make_shared(shape_of, + opset8::Constant::create(element::i32, Shape{1}, {0}), + opset8::Constant::create(element::i32, Shape{1}, {0})); + auto equal = make_shared(gather_shape, opset8::Constant::create(element::i64, Shape{1}, {1})); + auto if_op = make_shared(equal); + + auto input_then = make_shared(element::i32, PartialShape{-1, 1}); + + auto start = opset8::Constant::create(element::i32, Shape{1}, {3}); + auto stop = opset8::Constant::create(element::i32, Shape{1}, {4}); + auto step = opset8::Constant::create(element::i32, Shape{1}, {1}); + auto slice = make_shared(input_then, start, stop, step); + + auto then_op_result = make_shared(slice); + auto body_then_function = make_shared(NodeVector{then_op_result}, ParameterVector{input_then}); + + auto input_else = make_shared(element::i32, PartialShape{-1, 1}); + auto reshape = + make_shared(input_else, opset8::Constant::create(element::i32, Shape{1}, {-1}), true); + auto else_op_result = make_shared(reshape); + auto body_else_function = make_shared(NodeVector{else_op_result}, ParameterVector{input_else}); + + if_op->set_then_body(body_then_function); + if_op->set_else_body(body_else_function); + if_op->set_input(gather, input_then, input_else); + + auto result_if = if_op->set_output(then_op_result, else_op_result); + + auto begin = opset8::Constant::create(element::i32, Shape{1}, {3}); + auto end = opset8::Constant::create(element::i32, Shape{1}, {4}); + auto strides = opset8::Constant::create(element::i32, Shape{1}, {1}); + auto ss_node = + make_shared(result_if, begin, end, strides, vector{1, 0}, vector{1, 0}); + + auto data = make_shared(element::f32, PartialShape{-1}); + auto axis = opset8::Constant::create(element::i32, Shape{1}, {0}); + auto target_gather = make_shared(data, ss_node, axis); + + shared_ptr f = make_shared(NodeVector{target_gather}, ParameterVector{boxes, scores, data}); + + pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + + const auto& ops = f->get_ops(); + const auto& gather_it = find(ops.begin(), ops.end(), target_gather); + ASSERT_NE(gather_it, ops.end()); + + const auto& rti = (*gather_it)->get_rt_info(); + const auto& reverse = rti.find("dontReverseIndices"); + ASSERT_NE(reverse, rti.end()); +}