Skip to content

Commit

Permalink
[Transformations] Added If operation to NMS path propagation for igno…
Browse files Browse the repository at this point in the history
…re negative indices in Gather
  • Loading branch information
Lyamin-Roman committed Mar 20, 2024
1 parent 7d5e4af commit 400e75f
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "openvino/core/rt_info.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/if.hpp"
#include "openvino/op/non_max_suppression.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/slice.hpp"
Expand All @@ -19,8 +20,10 @@
#include "openvino/op/util/broadcast_base.hpp"
#include "openvino/op/util/gather_base.hpp"
#include "openvino/op/variadic_split.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/nms_selected_indices.hpp"
#include "openvino/op/util/multi_subgraph_base.hpp"

using namespace std;

Expand Down Expand Up @@ -60,14 +63,51 @@ class PropagateNMSPath : public pass::MatcherPass {
ov::op::v1::VariadicSplit,
op::util::GatherBase,
ov::op::v0::Concat,
ov::op::v0::Convert>();
ov::op::v0::Convert,
ov::op::v8::If>();
matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto propagate_path = [](const ov::OutputVector& input_nodes, ov::Node* target_node) {
if (any_of(input_nodes.begin(), input_nodes.end(), [](const Output<Node>& output) {
return ov::has_nms_selected_indices(output.get_node());
})) {
ov::set_nms_selected_indices(target_node);
}
};
auto handle_params =
[&propagate_path](std::shared_ptr<ov::op::util::MultiSubGraphOp> if_node, std::shared_ptr<ov::Model> body, int body_index) {
const auto& params = body->get_parameters();
for (auto input_desc : if_node->get_input_descriptions(body_index)) {
auto param = params[input_desc->m_body_parameter_index];
auto input_node = if_node->input(input_desc->m_input_index).get_source_output();
propagate_path({input_node}, param.get());
}
};
auto handle_results =
[&propagate_path](std::shared_ptr<ov::op::util::MultiSubGraphOp> if_node, std::shared_ptr<ov::Model> body, int body_index) {
const auto& results = body->get_results();
for (auto output_desc : if_node->get_output_descriptions(body_index)) {
auto result = results[output_desc->m_body_value_index];
const auto& result_inputs = result->input_values();
auto output_node = if_node->output(output_desc->m_output_index).get_node();
propagate_path(result_inputs, output_node);
}
};

auto node = m.get_match_root();
const auto& inputs = node->input_values();
if (any_of(inputs.begin(), inputs.end(), [](const Output<Node>& output) {
return ov::has_nms_selected_indices(output.get_node());
})) {
ov::set_nms_selected_indices(node.get());
if (ov::is_type<ov::op::util::MultiSubGraphOp>(node)) {
auto multi_subgraph_op = ov::as_type_ptr<ov::op::util::MultiSubGraphOp>(node);
const auto& models = multi_subgraph_op->get_functions();

for (size_t body_idx = 0; body_idx < models.size(); ++body_idx) {
handle_params(multi_subgraph_op, models[body_idx], static_cast<int>(body_idx));
ov::pass::Manager manager;
manager.register_pass<ov::pass::PropagateNMSPath>();
manager.run_passes(models[body_idx]);
handle_results(multi_subgraph_op, models[body_idx], static_cast<int>(body_idx));
}
} else {
const auto& inputs = node->input_values();
propagate_path(inputs, node.get());
}
return false;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,68 @@ TEST(TransformationTests, test_convert_to_unsigned_nms_gather_3) {
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(count_ops_of_type<ov::op::v0::Convert>(f), 0);
}

TEST(TransformationTests, test_convert_to_unsigned_nms_gather_with_if_condition) {
auto boxes = make_shared<opset8::Parameter>(element::f32, PartialShape{1, -1, 4});
auto scores = make_shared<opset8::Parameter>(element::f32, PartialShape{1, 1, -1});
auto nms = make_shared<opset8::NonMaxSuppression>(boxes, scores);

auto gather = make_shared<opset8::Gather>(nms->output(0),
opset8::Constant::create(element::i32, Shape{1}, {2}),
opset8::Constant::create(element::i32, Shape{1}, {0}));

auto shape_of = make_shared<opset8::ShapeOf>(gather);
auto gather_shape = make_shared<opset8::Gather>(shape_of,
opset8::Constant::create(element::i32, Shape{1}, {0}),
opset8::Constant::create(element::i32, Shape{1}, {0}));
auto equal = make_shared<opset8::Equal>(gather_shape, opset8::Constant::create(element::i64, Shape{1}, {1}));
auto if_op = make_shared<opset8::If>(equal);

auto input_then = make_shared<opset8::Parameter>(element::i32, PartialShape{-1, 1});

auto start = opset8::Constant::create(element::i32, Shape{1}, {3});
auto stop = opset8::Constant::create(element::i32, Shape{1}, {4});
auto step = opset8::Constant::create(element::i32, Shape{1}, {1});
auto slice = make_shared<opset8::Slice>(input_then, start, stop, step);

auto then_op_result = make_shared<op::v0::Result>(slice);
auto body_then_function = make_shared<Model>(NodeVector{then_op_result}, ParameterVector{input_then});

auto input_else = make_shared<opset8::Parameter>(element::i32, PartialShape{-1, 1});
auto reshape =
make_shared<opset8::Reshape>(input_else, opset8::Constant::create(element::i32, Shape{1}, {-1}), true);
auto else_op_result = make_shared<op::v0::Result>(reshape);
auto body_else_function = make_shared<Model>(NodeVector{else_op_result}, ParameterVector{input_else});

if_op->set_then_body(body_then_function);
if_op->set_else_body(body_else_function);
if_op->set_input(gather, input_then, input_else);

auto result_if = if_op->set_output(then_op_result, else_op_result);

auto begin = opset8::Constant::create(element::i32, Shape{1}, {3});
auto end = opset8::Constant::create(element::i32, Shape{1}, {4});
auto strides = opset8::Constant::create(element::i32, Shape{1}, {1});
auto ss_node =
make_shared<opset8::StridedSlice>(result_if, begin, end, strides, vector<int64_t>{1, 0}, vector<int64_t>{1, 0});

auto data = make_shared<op::v0::Parameter>(element::f32, PartialShape{-1});
auto axis = opset8::Constant::create(element::i32, Shape{1}, {0});
auto target_gather = make_shared<opset8::Gather>(data, ss_node, axis);

shared_ptr<Model> f = make_shared<Model>(NodeVector{target_gather}, ParameterVector{boxes, scores, data});

pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertNmsGatherPathToUnsigned>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));

const auto& ops = f->get_ops();
const auto& gather_it = find(ops.begin(), ops.end(), target_gather);
ASSERT_NE(gather_it, ops.end());

const auto& rti = (*gather_it)->get_rt_info();
const auto& reverse = rti.find("dontReverseIndices");
ASSERT_NE(reverse, rti.end());
}

0 comments on commit 400e75f

Please sign in to comment.