Skip to content

Commit

Permalink
[Transformations] Added If operation to NMS path propagation for igno…
Browse files Browse the repository at this point in the history
…re negative indices in Gather
  • Loading branch information
Lyamin-Roman committed Mar 18, 2024
1 parent 7f94e35 commit 8f39532
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "openvino/core/rt_info.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/if.hpp"
#include "openvino/op/non_max_suppression.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/slice.hpp"
Expand Down Expand Up @@ -60,7 +61,8 @@ class PropagateNMSPath : public pass::MatcherPass {
ov::op::v1::VariadicSplit,
op::util::GatherBase,
ov::op::v0::Concat,
ov::op::v0::Convert>();
ov::op::v0::Convert,
ov::op::v8::If>();
matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto node = m.get_match_root();
const auto& inputs = node->input_values();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,62 @@ TEST(TransformationTests, test_convert_to_unsigned_nms_gather_3) {
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(count_ops_of_type<ov::op::v0::Convert>(f), 0);
}

TEST(TransformationTests, test_convert_to_unsigned_nms_gather_with_if_condition) {
auto boxes = make_shared<opset8::Parameter>(element::f32, PartialShape{1, -1, 4});
auto scores = make_shared<opset8::Parameter>(element::f32, PartialShape{1, 1, -1});
auto nms = make_shared<opset8::NonMaxSuppression>(boxes, scores);

auto gather = make_shared<opset8::Gather>(nms->output(0),
opset8::Constant::create(element::i32, Shape{1}, {2}),
opset8::Constant::create(element::i32, Shape{1}, {0}));

auto shape_of = make_shared<opset8::ShapeOf>(gather);
auto gather_shape = make_shared<opset8::Gather>(shape_of,
opset8::Constant::create(element::i32, Shape{1}, {0}),
opset8::Constant::create(element::i32, Shape{1}, {0}));
auto equal = make_shared<opset8::Equal>(gather_shape, opset8::Constant::create(element::i64, Shape{1}, {1}));
auto if_op = make_shared<opset8::If>(equal);

auto input_then = std::make_shared<opset8::Parameter>(element::i32, PartialShape{-1, 1});
auto then_op_result = std::make_shared<op::v0::Result>(gather);
auto body_then_function =
std::make_shared<Model>(NodeVector{then_op_result}, ParameterVector{input_then, boxes, scores});

auto input_else = std::make_shared<opset8::Parameter>(element::i32, PartialShape{-1, 1});
auto else_op_result = std::make_shared<op::v0::Result>(gather);
auto body_else_function =
std::make_shared<Model>(NodeVector{else_op_result}, ParameterVector{input_else, boxes, scores});

if_op->set_then_body(body_then_function);
if_op->set_else_body(body_else_function);
if_op->set_input(gather, input_then, input_else);

auto result_if = if_op->set_output(then_op_result, else_op_result);

auto begin = opset8::Constant::create(element::i32, Shape{1}, {3});
auto end = opset8::Constant::create(element::i32, Shape{1}, {4});
auto strides = opset8::Constant::create(element::i32, Shape{1}, {1});
auto ss_node =
make_shared<opset8::StridedSlice>(result_if, begin, end, strides, vector<int64_t>{1, 0}, vector<int64_t>{1, 0});

auto data = make_shared<op::v0::Parameter>(element::f32, PartialShape{-1});
auto axis = opset8::Constant::create(element::i32, Shape{1}, {0});
auto target_gather = make_shared<opset8::Gather>(data, ss_node, axis);

shared_ptr<Model> f = make_shared<Model>(NodeVector{target_gather}, ParameterVector{boxes, scores, data});

pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertNmsGatherPathToUnsigned>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));

const auto& ops = f->get_ops();
const auto& gather_it = std::find(ops.begin(), ops.end(), target_gather);
ASSERT_NE(gather_it, ops.end());

const auto& rti = (*gather_it)->get_rt_info();
const auto& reverse = rti.find("dontReverseIndices");
ASSERT_NE(reverse, rti.end());
}

0 comments on commit 8f39532

Please sign in to comment.