Skip to content

Commit

Permalink
wip: not stable
Browse files Browse the repository at this point in the history
  • Loading branch information
CuriousPanCake committed Mar 20, 2024
1 parent 31f8cb0 commit b4a33fb
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def test_any_input_predicate():


def test_optional_full_match():
model_abs = ops.abs(AnyInput())
model_input = ops.parameter(PartialShape.dynamic())
model_abs = ops.abs(model_input)
model_relu = ops.relu(model_abs.output(0))

pattern_abs = Optional(["opset13.Abs"])
Expand All @@ -97,14 +98,15 @@ def test_optional_full_match():


def test_optional_half_match():
model_abs = ops.add(AnyInput(), AnyInput())
model_relu = ops.relu(model_abs.output(0))
model_input = ops.parameter(PartialShape.dynamic())
model_relu = ops.relu(model_input)
model_relu1 = ops.relu(model_relu.output(0))

pattern_relu = Optional(["opset13.Relu"])
pattern_relu = Optional(["opset13.Abs"])
pattern_relu1 = ops.relu(pattern_relu.output(0))

matcher = Matcher(pattern_relu1, "FindRelu")
assert matcher.match(model_relu)
assert matcher.match(model_relu1)


def test_optional_one_node():
Expand Down
1 change: 1 addition & 0 deletions src/core/src/pattern/op/optional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ bool ov::pass::pattern::op::Optional::match_value(Matcher* matcher,
auto pattern = num_input_values_to_optional == 0 ? std::static_pointer_cast<Pattern>(wrap_node)
: std::static_pointer_cast<Pattern>(std::make_shared<Or>(
OutputVector{wrap_node, input_values_to_optional[0]}));
// bool check = (pattern_value.get_node_shared_ptr()->get_output_size() != 0 && num_input_values_to_optional == 0);

if (matcher->match_value(pattern, graph_value) || (same_type && num_input_values_to_optional == 0)) {
auto& pattern_map = matcher->get_pattern_value_map();
Expand Down
26 changes: 12 additions & 14 deletions src/core/tests/pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,32 +512,30 @@ TEST(pattern, matching_optional) {

TEST(pattern, optional_full_match) {
Shape shape{};
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_add = std::make_shared<op::v1::Add>(model_input1->output(0), model_input2->output(0));
auto model_relu = std::make_shared<op::v0::Relu>(model_add->output(0));
auto model_input = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_relu = std::make_shared<op::v0::Relu>(model_input);
auto model_relu1 = std::make_shared<op::v0::Relu>(model_relu->output(0));

auto pattern_add = ov::pass::pattern::optional<op::v1::Add>();
auto pattern_relu = std::make_shared<op::v0::Relu>(pattern_add->output(0));
auto pattern_relu = ov::pass::pattern::optional<op::v0::Relu>();
auto pattern_relu1 = std::make_shared<op::v0::Relu>(pattern_relu->output(0));

TestMatcher tm;

ASSERT_TRUE(tm.match(pattern_relu, model_relu));
ASSERT_TRUE(tm.match(pattern_relu1, model_relu1));
}

TEST(pattern, optional_half_match) {
Shape shape{};
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_add = std::make_shared<op::v1::Add>(model_input1->output(0), model_input2->output(0));
auto model_relu = std::make_shared<op::v0::Relu>(model_add->output(0));
auto model_input = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_relu = std::make_shared<op::v0::Relu>(model_input);
auto model_relu1 = std::make_shared<op::v0::Relu>(model_relu->output(0));

auto pattern_relu = ov::pass::pattern::optional<op::v0::Relu>();
auto pattern_relu1 = std::make_shared<op::v0::Relu>(pattern_relu->output(0));
auto pattern_abs = ov::pass::pattern::optional<op::v0::Abs>();
auto pattern_relu = std::make_shared<op::v0::Relu>(pattern_abs->output(0));

TestMatcher tm;

ASSERT_TRUE(tm.match(pattern_relu1, model_relu));
ASSERT_TRUE(tm.match(pattern_relu, model_relu1));
}

TEST(pattern, optional_testing) {
Expand Down

0 comments on commit b4a33fb

Please sign in to comment.