Skip to content

Commit

Permalink
NON_PATTERN_NODER_WITH_TYPE attribute for pattern matching is introduced
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Oct 17, 2023
1 parent d7b342e commit e755b56
Show file tree
Hide file tree
Showing 6 changed files with 549 additions and 484 deletions.
15 changes: 11 additions & 4 deletions nncf/common/graph/graph_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

from nncf.common.graph.patterns import GraphPattern

ATTRS_TO_SKIP = [GraphPattern.LABEL_ATTR, GraphPattern.NON_PATTERN_NODE_WITH_TYPE]


def _are_nodes_matched(node_1, node_2) -> bool:
for attr in node_2:
if attr == GraphPattern.LABEL_ATTR:
if attr in ATTRS_TO_SKIP:
continue
if attr == GraphPattern.METATYPE_ATTR:
# GraphPattern.ANY_PATTERN_NODE_TYPE and GraphPattern.NON_PATTERN_NODE_TYPE
Expand Down Expand Up @@ -103,7 +105,8 @@ def _is_subgraph_matching_strict(graph: nx.DiGraph, pattern: nx.DiGraph, subgrap

def _copy_subgraph_excluding_non_pattern_node(subgraph: Dict[str, str], pattern_graph: GraphPattern) -> Dict[str, str]:
"""
Copies a matching subgraph excluding the nodes having GraphPattern.NON_PATTERN_NODE_TYPE.
Copies a matching subgraph excluding the nodes having GraphPattern.NON_PATTERN_NODE_TYPE
or GraphPattern.NON_PATTERN_NODE_WITH_TYPE.
:param subgraph: Subgraph
:param pattern_graph: A graph consists of patterns to match.
Expand All @@ -113,8 +116,12 @@ def _copy_subgraph_excluding_non_pattern_node(subgraph: Dict[str, str], pattern_
for node_from_graph, node_from_pattern in subgraph.items():
pattern_node = pattern_graph.graph.nodes[node_from_pattern]
pattern_node_types = pattern_node.get(GraphPattern.METATYPE_ATTR, [])
if GraphPattern.NON_PATTERN_NODE_TYPE not in pattern_node_types:
output[node_from_graph] = node_from_pattern
if GraphPattern.NON_PATTERN_NODE_TYPE in pattern_node_types:
continue
if pattern_node.get(GraphPattern.NON_PATTERN_NODE_WITH_TYPE, False):
continue
output[node_from_graph] = node_from_pattern

return output


Expand Down
1 change: 1 addition & 0 deletions nncf/common/graph/patterns/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class GraphPattern:
NODE_TYPE_ATTR = "metatype"
ANY_PATTERN_NODE_TYPE = "ANY_PATTERN_NODE"
NON_PATTERN_NODE_TYPE = "NON_PATTERN_NODE"
NON_PATTERN_NODE_WITH_TYPE = "NON_PATTERN_NODE_WITH_TYPE"

def __init__(self):
self._graph = nx.DiGraph()
Expand Down
12 changes: 10 additions & 2 deletions nncf/openvino/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ def create_se_block() -> GraphPattern:
**{GraphPattern.LABEL_ATTR: "ANY", GraphPattern.METATYPE_ATTR: GraphPattern.NON_PATTERN_NODE_TYPE}
)
reduce_mean_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "REDUCE_MEAN", GraphPattern.METATYPE_ATTR: om.OVReduceMeanMetatype}
**{
GraphPattern.LABEL_ATTR: "REDUCE_MEAN",
GraphPattern.METATYPE_ATTR: om.OVReduceMeanMetatype,
GraphPattern.NON_PATTERN_NODE_WITH_TYPE: True,
}
)
linear_node_1 = pattern.add_node(
**{GraphPattern.METATYPE_ATTR: LINEAR_OPERATIONS, GraphPattern.LABEL_ATTR: "LINEAR"}
Expand All @@ -138,7 +142,11 @@ def create_se_block() -> GraphPattern:
**{GraphPattern.LABEL_ATTR: "SIGMOID", GraphPattern.METATYPE_ATTR: om.OVSigmoidMetatype}
)
multiply_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "MULTIPLY", GraphPattern.METATYPE_ATTR: om.OVMultiplyMetatype}
**{
GraphPattern.LABEL_ATTR: "MULTIPLY",
GraphPattern.METATYPE_ATTR: om.OVMultiplyMetatype,
GraphPattern.NON_PATTERN_NODE_WITH_TYPE: True,
}
)

pattern.add_edge(any_node, reduce_mean_node)
Expand Down
14 changes: 10 additions & 4 deletions nncf/torch/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,17 @@ def create_se_block() -> GraphPattern:
MEAN_OPERATIONS = {
GraphPattern.LABEL_ATTR: "REDUCE_MEAN",
GraphPattern.METATYPE_ATTR: ["avg_pool2d", "adaptive_avg_pool2d", "avg_pool3d", "adaptive_avg_pool3d", "mean"],
GraphPattern.NON_PATTERN_NODE_WITH_TYPE: True,
}
SYGMOID_OPERATIONS = {
GraphPattern.LABEL_ATTR: "SIGMOID",
GraphPattern.METATYPE_ATTR: ["sigmoid", "hardsigmoid"],
}
MUL_OPERATION = {
GraphPattern.LABEL_ATTR: "MUL",
GraphPattern.METATYPE_ATTR: "__mul__",
GraphPattern.NON_PATTERN_NODE_WITH_TYPE: True,
}

def get_se_block_pattern() -> GraphPattern:
pattern = GraphPattern()
Expand All @@ -112,7 +118,7 @@ def get_se_block_pattern() -> GraphPattern:
activation_node_1 = pattern.add_node(**ATOMIC_ACTIVATIONS_OPERATIONS)
linear_node_2 = pattern.add_node(**LINEAR_OPERATIONS)
activation_node_2 = pattern.add_node(**SYGMOID_OPERATIONS)
multiply_node = pattern.add_node(label="MUL", type="__mul__")
multiply_node = pattern.add_node(**MUL_OPERATION)

pattern.add_edge(any_node, reduce_mean_node)
pattern.add_edge(reduce_mean_node, linear_node_1)
Expand All @@ -133,7 +139,7 @@ def get_se_block_with_bias_pattern() -> GraphPattern:
linear_node_2 = pattern.add_node(**LINEAR_OPERATIONS)
add_node_2 = pattern.add_node(label="ADD_BIAS", type=["__add__", "__sub__"])
activation_node_2 = pattern.add_node(**SYGMOID_OPERATIONS)
multiply_node = pattern.add_node(label="MUL", type="__mul__")
multiply_node = pattern.add_node(**MUL_OPERATION)

pattern.add_edge(any_node, reduce_mean_node)
pattern.add_edge(reduce_mean_node, linear_node_1)
Expand Down Expand Up @@ -161,7 +167,7 @@ def get_se_block_with_reshape() -> GraphPattern:
linear_node_2 = pattern.add_node(**LINEAR_OPERATIONS)
activation_node_2 = pattern.add_node(**SYGMOID_OPERATIONS)
reshape_node_2 = pattern.add_node(**RESHAPE_NODES)
multiply_node = pattern.add_node(label="MUL", type="__mul__")
multiply_node = pattern.add_node(**MUL_OPERATION)

pattern.add_edge(any_node, reduce_mean_node)
pattern.add_edge(reduce_mean_node, reshape_node_1)
Expand All @@ -186,7 +192,7 @@ def get_se_block_with_bias_and_reshape() -> GraphPattern:
add_node_2 = pattern.add_node(label="ADD_BIAS", type=["__add__", "__sub__"])
activation_node_2 = pattern.add_node(**SYGMOID_OPERATIONS)
reshape_node_2 = pattern.add_node(**RESHAPE_NODES)
multiply_node = pattern.add_node(label="MUL", type="__mul__")
multiply_node = pattern.add_node(**MUL_OPERATION)

pattern.add_edge(any_node, reduce_mean_node)
pattern.add_edge(reduce_mean_node, reshape_node_1)
Expand Down
25 changes: 25 additions & 0 deletions tests/common/graph/test_graph_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,28 @@ def test_not_match_edges_inside_pattern():
pattern.add_edge(node_1, node_3)
matches = find_subgraphs_matching_pattern(ref_graph, pattern)
assert matches == [["1", "2", "3"]]


def test_non_pattern_graph_with_type():
for match in [False, True]:
ref_graph = nx.DiGraph()
ref_graph.add_node("0", **{GraphPattern.METATYPE_ATTR: "0"})
ref_graph.add_node("1", **{GraphPattern.METATYPE_ATTR: "a" if match else "0"})
ref_graph.add_node("2", **{GraphPattern.METATYPE_ATTR: "b"})
ref_graph.add_node("3", **{GraphPattern.METATYPE_ATTR: "c"})
ref_graph.add_edge("0", "1")
ref_graph.add_edge("1", "2")
ref_graph.add_edge("2", "3")

pattern = GraphPattern()
node_1 = pattern.add_node(**{GraphPattern.METATYPE_ATTR: "a", GraphPattern.NON_PATTERN_NODE_WITH_TYPE: True})
node_2 = pattern.add_node(**{GraphPattern.METATYPE_ATTR: "b"})
node_3 = pattern.add_node(**{GraphPattern.METATYPE_ATTR: "c"})
pattern.add_edge(node_1, node_2)
pattern.add_edge(node_2, node_3)

matches = find_subgraphs_matching_pattern(ref_graph, pattern)
if not match:
assert not matches
else:
assert matches == [["2", "3"]]
Loading

0 comments on commit e755b56

Please sign in to comment.