Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TRANSFORMATIONS] Create python binding for pattern::Optional #23558

Merged
merged 8 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# type: ignore
# flake8: noqa

from openvino._pyopenvino.passes import ModelPass, Matcher, MatcherPass, PassBase, WrapType, Or, AnyInput
from openvino._pyopenvino.passes import ModelPass, Matcher, MatcherPass, PassBase, WrapType, Or, AnyInput, Optional
from openvino._pyopenvino.passes import (
consumers_count,
has_static_dim,
Expand Down
106 changes: 106 additions & 0 deletions src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <string>

#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
Expand Down Expand Up @@ -482,6 +483,110 @@ static void reg_pattern_any_input(py::module m) {
});
}

static void reg_pattern_optional(py::module m) {
py::class_<ov::pass::pattern::op::Optional, std::shared_ptr<ov::pass::pattern::op::Optional>, ov::Node>
optional_type(m, "Optional");
optional_type.doc() = "openvino.runtime.passes.Optional wraps ov::pass::pattern::op::Optional";

optional_type.def(py::init([](const std::vector<std::string>& type_names) {
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names));
}),
py::arg("type_name"),
R"(
Create Optional with the given node type.

:param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]
)");

optional_type.def(py::init([](const std::vector<std::string>& type_names, const Predicate& pred) {
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names), pred);
}),
py::arg("type_names"),
py::arg("pred"),
CuriousPanCake marked this conversation as resolved.
Show resolved Hide resolved
R"(
Create Optional with the given node type and predicate.

:param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]

:param predicate: Function that performs additional checks for matching.
:type predicate: function
)");

optional_type.def(py::init([](const std::vector<std::string>& type_names, const ov::Output<ov::Node>& input, const Predicate& pred) {
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names), input, pred);
}),
py::arg("type_names"),
py::arg("input"),
py::arg("pred"),
CuriousPanCake marked this conversation as resolved.
Show resolved Hide resolved
R"(
Create Optional with the given node type, input node and predicate.

:param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]

:param input: input node's output.
:type input: openvino.runtime.Output

:param predicate: Function that performs additional checks for matching.
:type predicate: function
)");

optional_type.def(py::init([](const std::vector<std::string>& type_names, const ov::Output<ov::Node>& input) {
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names), input, nullptr);
}),
py::arg("type_names"),
py::arg("input"),
R"(
Create Optional with the given node type and input node.

:param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]

:param input: input node's output.
:type input: openvino.runtime.Output
)");

optional_type.def(py::init([](const std::vector<std::string>& type_names, const std::shared_ptr<ov::Node>& input) {
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names), input, nullptr);
}),
py::arg("type_names"),
py::arg("input"),
R"(
Create Optional with the given node type and input node.

:param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]

:param input: input node
:type input: openvino.runtime.Node
)");

optional_type.def(py::init([](const std::vector<std::string>& type_names, const std::shared_ptr<ov::Node>& input, const Predicate& pred) {
return std::make_shared<ov::pass::pattern::op::Optional>(get_types(type_names), input, pred);
}),
py::arg("type_names"),
py::arg("input"),
py::arg("pred"),
CuriousPanCake marked this conversation as resolved.
Show resolved Hide resolved
R"(
Create Optional with the given node type, input node and predicate.

:param type_names: node type. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]

:param input: input node
:type input: openvino.runtime.Node

:param predicate: Function that performs additional checks for matching.
:type predicate: function
)");

optional_type.def("__repr__", [](const ov::pass::pattern::op::Optional& self) {
return Common::get_simple_repr(self);
});
}

inline void reg_predicates(py::module m) {
m.def("consumers_count", &ov::pass::pattern::consumers_count);
m.def("has_static_dim", &ov::pass::pattern::has_static_dim);
Expand All @@ -497,5 +602,6 @@ void reg_passes_pattern_ops(py::module m) {
reg_pattern_any_input(m);
reg_pattern_wrap_type(m);
reg_pattern_or(m);
reg_pattern_optional(m);
reg_predicates(m);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from openvino import PartialShape
from openvino.runtime import opset13 as ops
from openvino.runtime.passes import Matcher, WrapType, Or, AnyInput
from openvino.runtime.passes import Matcher, WrapType, Or, AnyInput, Optional
from openvino.runtime.passes import (
consumers_count,
has_static_dim,
Expand Down Expand Up @@ -85,6 +85,94 @@ def test_any_input_predicate():
assert not matcher.match(slope)


def test_optional_full_match():
akuporos marked this conversation as resolved.
Show resolved Hide resolved
itikhono marked this conversation as resolved.
Show resolved Hide resolved
model_abs = ops.abs(AnyInput())
model_relu = ops.relu(model_abs.output(0))

pattern_abs = Optional(["opset13.Abs"])
pattern_relu = ops.relu(pattern_abs.output(0))

matcher = Matcher(pattern_relu, "FindRelu")
assert matcher.match(model_relu)


def test_optional_half_match():
model_abs = ops.add(AnyInput(), AnyInput())
model_relu = ops.relu(model_abs.output(0))

pattern_relu = Optional(["opset13.Relu"])
pattern_relu1 = ops.relu(pattern_relu.output(0))

matcher = Matcher(pattern_relu1, "FindRelu")
assert matcher.match(model_relu)


def test_optional_one_node():
model_input = ops.parameter(PartialShape.dynamic())
model_relu = ops.relu(model_input)
model_abs = ops.abs(model_input)

assert Matcher(Optional(["opset13.Relu"]), "OneNodeTest").match(model_relu)
assert not Matcher(Optional(["opset13.Abs"]), "OneNodeTest").match(model_relu)

assert not Matcher(Optional(["opset13.Relu"]), "OneNodeTest").match(model_abs)

assert Matcher(Optional(["opset13.Parameter"]), "OneNodeTest").match(ops.parameter(PartialShape.dynamic()))
assert not Matcher(Optional(["opset13.Relu"]), "OneNodeTest").match(ops.parameter(PartialShape.dynamic()))


def test_optional_predicate():
model_input = ops.parameter(PartialShape.dynamic())
model_add = ops.add(model_input, model_input)
model_relu = ops.relu(model_add.output(0))
model_abs = ops.abs(model_add.output(0))

assert Matcher(Optional(["opset13.Relu"], lambda x: True), "TestInputPredicate").match(model_relu)
assert not Matcher(Optional(["opset13.Relu"], lambda x: False), "TestInputPredicate").match(model_relu)
assert Matcher(Optional(["opset13.Add"], consumers_count(2)), "FindPredicate").match(model_add)
assert not Matcher(Optional(["opset13.Add"], consumers_count(1)), "FindPredicate").match(model_add)
assert Matcher(Optional(["opset13.Abs", "opset13.Result"], consumers_count(0)), "FindPredicate").match(model_abs)


def test_optional_with_input():
model_input = ops.parameter(PartialShape.dynamic())
model_add = ops.add(model_input, model_input)
model_relu = ops.relu(model_add.output(0))

assert Matcher(Optional(["opset13.Relu"], model_add.output(0)), "TestInput").match(model_relu)
assert not Matcher(Optional(["opset13.Cos"], model_add.output(0)), "TestInput").match(model_relu)


def test_optional_with_input_and_predicate():
model_input = ops.parameter(PartialShape.dynamic())
model_add = ops.add(model_input, model_input)
model_relu = ops.relu(model_add.output(0))

pattern_add = ops.add(AnyInput(), AnyInput())

assert Matcher(Optional(["opset13.Relu"], pattern_add.output(0), lambda x: True), "TestInputPredicate").match(model_relu)
assert not Matcher(Optional(["opset13.Relu"], pattern_add.output(0), lambda x: False), "TestInputPredicate").match(model_relu)


def test_optional_with_input_node():
model_input = ops.parameter(PartialShape.dynamic())
model_add = ops.add(model_input, model_input)
model_relu = ops.relu(model_add.output(0))

assert Matcher(Optional(["opset13.Relu"], model_add), "TestInputNode").match(model_relu)
assert not Matcher(Optional(["opset13.Cos"], model_add), "TestInputNode").match(model_relu)


def test_optional_with_input_node_and_predicate():
model_input = ops.parameter(PartialShape.dynamic())
model_add = ops.add(model_input, model_input)
model_relu = ops.relu(model_add.output(0))

assert Matcher(Optional(["opset13.Relu"], model_add, lambda x: True), "TestInputNodePredicate").match(model_relu)
assert not Matcher(Optional(["opset13.Relu"], model_add, lambda x: False), "TestInputNodePredicate").match(model_relu)
assert not Matcher(Optional(["opset13.Cos"], model_add, lambda x: True), "TestInputNodePredicate").match(model_relu)


def test_all_predicates():
static_param = ops.parameter(PartialShape([1, 3, 22, 22]), np.float32)
dynamic_param = ops.parameter(PartialShape([-1, 6]), np.compat.long)
Expand Down
3 changes: 2 additions & 1 deletion src/core/src/pattern/op/optional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ bool ov::pass::pattern::op::Optional::match_value(Matcher* matcher,
// Turn the Optional node into WrapType node to create a case where the Optional node is present
ov::OutputVector input_values_to_optional = input_values();
size_t num_input_values_to_optional = input_values_to_optional.size();
bool same_type = pattern_value.get_element_type() == graph_value.get_element_type();
itikhono marked this conversation as resolved.
Show resolved Hide resolved
auto wrap_node = std::make_shared<WrapType>(optional_types, m_predicate, input_values_to_optional);

// Either continue using the WrapType if there're no inputs to it or create an Or node,
Expand All @@ -53,7 +54,7 @@ bool ov::pass::pattern::op::Optional::match_value(Matcher* matcher,
: std::static_pointer_cast<Pattern>(std::make_shared<Or>(
OutputVector{wrap_node, input_values_to_optional[0]}));

if (matcher->match_value(pattern, graph_value) || num_input_values_to_optional == 0) {
if (matcher->match_value(pattern, graph_value) || (same_type && num_input_values_to_optional == 0)) {
auto& pattern_map = matcher->get_pattern_value_map();
if (pattern_map.count(wrap_node)) {
pattern_map[shared_from_this()] = graph_value;
Expand Down
17 changes: 17 additions & 0 deletions src/core/tests/pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,23 @@ TEST(pattern, optional_testing) {
std::make_shared<op::v0::Relu>(std::make_shared<op::v0::Relu>(model_add))));
}

TEST(pattern, optional_one_node) {
Shape shape{};
auto model_input = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_relu = std::make_shared<op::v0::Relu>(model_input);
auto model_abs = std::make_shared<op::v0::Abs>(model_input);

TestMatcher tm;

ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Relu>(), model_relu));
ASSERT_FALSE(tm.match(ov::pass::pattern::optional<op::v0::Abs>(), model_relu));

ASSERT_FALSE(tm.match(ov::pass::pattern::optional<op::v0::Relu>(), model_abs));

ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Parameter>(), model_input));
ASSERT_FALSE(tm.match(ov::pass::pattern::optional<op::v0::Relu>(), model_input));
}

TEST(pattern, mean) {
// construct mean
TestMatcher n;
Expand Down
Loading