diff --git a/src/core/src/pattern/op/optional.cpp b/src/core/src/pattern/op/optional.cpp index 0234f84ca1e9f3..83f231b711d37d 100644 --- a/src/core/src/pattern/op/optional.cpp +++ b/src/core/src/pattern/op/optional.cpp @@ -18,7 +18,8 @@ bool ov::pass::pattern::op::Optional::match_value(Matcher* matcher, // Turn the Optional node into WrapType node to create a case where the Optional node is present ov::OutputVector input_values_to_optional = input_values(); size_t num_input_values_to_optional = input_values_to_optional.size(); - auto wrap_node = std::make_shared(optional_types, m_predicate, input_values_to_optional); + auto wrap_node = + std::make_shared(optional_types, m_predicate, input_values_to_optional); // Add the newly created WrapType node to the list containing its inputs and create an Or node with the list input_values_to_optional.push_back(wrap_node); diff --git a/src/core/tests/pattern.cpp b/src/core/tests/pattern.cpp index fdd9a783f91741..d585bbc3ba703e 100644 --- a/src/core/tests/pattern.cpp +++ b/src/core/tests/pattern.cpp @@ -508,6 +508,63 @@ TEST(pattern, matching_optional) { std::make_shared(c))); } +TEST(pattern, optional_full_match) { + Shape shape{}; + auto model_input1 = std::make_shared(element::i32, shape); + auto model_input2 = std::make_shared(element::i32, shape); + auto model_add = std::make_shared(model_input1->output(0), model_input2->output(0)); + auto model_relu = std::make_shared(model_add->output(0)); + + auto pattern_add = ov::pass::pattern::optional(); + auto pattern_relu = std::make_shared(pattern_add->output(0)); + + TestMatcher tm; + + ASSERT_TRUE(tm.match(pattern_relu, model_relu)); +} + +TEST(pattern, optional_half_match) { + Shape shape{}; + auto model_input1 = std::make_shared(element::i32, shape); + auto model_input2 = std::make_shared(element::i32, shape); + auto model_add = std::make_shared(model_input1->output(0), model_input2->output(0)); + auto model_relu = std::make_shared(model_add->output(0)); + + auto pattern_relu = ov::pass::pattern::optional(); + auto pattern_relu1 = std::make_shared(pattern_relu->output(0)); + + TestMatcher tm; + + ASSERT_TRUE(tm.match(pattern_relu1, model_relu)); +} + +TEST(pattern, optional_new_test) { + Shape shape{}; + auto model_input1 = std::make_shared(element::i32, shape); + auto model_input2 = std::make_shared(element::i32, shape); + auto model_add = std::make_shared(model_input1->output(0), model_input2->output(0)); + auto model_relu = std::make_shared(model_add->output(0)); + auto model_abs = std::make_shared(model_add->output(0)); + + TestMatcher tm; + + ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_add), model_add)); + ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_add), model_add)); + ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_add), model_add)); + ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_add), model_add)); + + ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_abs), std::make_shared(model_abs))); + ASSERT_FALSE(tm.match(ov::pass::pattern::optional(model_abs), std::make_shared(model_abs))); + ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_abs), std::make_shared(model_abs))); + + ASSERT_FALSE(tm.match(ov::pass::pattern::optional(model_add), model_abs)); + ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_add), model_abs)); + + ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_relu), std::make_shared(std::make_shared(model_add)))); + + ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_relu), std::make_shared(std::make_shared(model_add)))); +} + TEST(pattern, mean) { // construct mean TestMatcher n;