Skip to content

Commit

Permalink
Fix Or pattern behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Nov 25, 2024
1 parent 287ab98 commit d343535
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/core/src/pattern/op/or.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ bool ov::pass::pattern::op::Or::match_value(Matcher* matcher,
auto saved = matcher->start_match();
if (matcher->match_value(input_value, graph_value)) {
auto& pattern_map = matcher->get_pattern_value_map();
pattern_map[input_value.get_node_shared_ptr()] = graph_value;
pattern_map[shared_from_this()] = graph_value;
return saved.finish(true);
}
}
Expand Down
57 changes: 57 additions & 0 deletions src/core/tests/pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,63 @@ TEST(pattern, optional_match_node_with_single_input) {
}
}

TEST(pattern, or_pattern_points_the_selected_branch) {
using namespace ov::op;
using namespace ov::pass::pattern;

// Graph:
auto model_param = make_shared<v0::Parameter>();
auto model_sigmoid = make_shared<v0::Sigmoid>(model_param);

// Pattern:
auto option_1 = wrap_type<v0::Parameter>();
auto option_2 = wrap_type<v0::Sigmoid>();
auto or_pattern = std::make_shared<pattern::op::Or>(ov::OutputVector{option_1, option_2});

// Test:
TestMatcher matcher;
EXPECT_TRUE(matcher.match(or_pattern, model_sigmoid));

auto pattern_val_mp = matcher.get_pattern_value_map();
EXPECT_NO_THROW(pattern_val_mp.at(or_pattern));

// we expect that Or pattern points to the first node of the selected branch
EXPECT_NE(ov::as_type<v0::Sigmoid>(pattern_val_mp.at(or_pattern).get_node()), nullptr);
}

TEST(pattern, multiple_optionals_in_row) {
using namespace ov::op;
using namespace ov::pass::pattern;

// Graph:
Shape shape{1, 2, 3};
auto model_input_0 = make_shared<v0::Parameter>(element::f32, shape);
auto model_sigmoid = make_shared<v0::Sigmoid>(model_input_0);

// Pattern:
auto in = wrap_type<v0::Parameter>();
auto pattern_convert = optional<v0::Convert>(in);
auto pattern_relu = optional<v0::Relu>(pattern_convert);
auto pattern_sigmoid = wrap_type<v0::Sigmoid>({pattern_relu});

// Test:
TestMatcher matcher;
EXPECT_TRUE(matcher.match(pattern_sigmoid, model_sigmoid));

auto pattern_val_mp = matcher.get_pattern_value_map();

EXPECT_NO_THROW(pattern_val_mp.at(in));
EXPECT_NE(ov::as_type<v0::Parameter>(pattern_val_mp.at(in).get_node()), nullptr);

// as Convert and Relu ops are not present in the graph, so we expect the optional nodes
// do not point to the graph nodes, in other words, the optional nodes are not in the pattern map.
EXPECT_EQ(pattern_val_mp.count(pattern_convert), 0);
EXPECT_EQ(pattern_val_mp.count(pattern_relu), 0);

EXPECT_NO_THROW(pattern_val_mp.at(pattern_sigmoid));
EXPECT_NE(ov::as_type<v0::Sigmoid>(pattern_val_mp.at(pattern_sigmoid).get_node()), nullptr);
}

// match optional nodes with multi input where order in not important
TEST(pattern, optional_match_cumulative_node_with_multi_input) {
Shape shape{1, 2, 3};
Expand Down

0 comments on commit d343535

Please sign in to comment.