Skip to content

Commit

Permalink
SE_BLOCK pattern moved from HW to IGNORED patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Oct 16, 2023
1 parent 7506e03 commit d7b342e
Show file tree
Hide file tree
Showing 15 changed files with 3,486 additions and 3,045 deletions.
32 changes: 7 additions & 25 deletions nncf/common/graph/patterns/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# limitations under the License.
from typing import Callable, Dict, Optional, Union

from nncf.common.graph.patterns.patterns import AlgorithmType
from nncf.common.graph.patterns.patterns import GraphPattern
from nncf.common.graph.patterns.patterns import HWFusedPatternNames
from nncf.common.graph.patterns.patterns import IgnoredPatternNames
Expand Down Expand Up @@ -75,10 +74,7 @@ def _get_backend_ignored_patterns_map(

@staticmethod
def _filter_patterns(
patterns_to_filter: Dict[PatternNames, Callable[[], GraphPattern]],
device: TargetDevice,
model_type: ModelType,
algorithm: AlgorithmType,
patterns_to_filter: Dict[PatternNames, Callable[[], GraphPattern]], device: TargetDevice, model_type: ModelType
) -> Dict[PatternNames, Callable[[], GraphPattern]]:
"""
        Returns all patterns from patterns_to_filter that satisfy the device and model_type parameters.
Expand All @@ -92,13 +88,9 @@ def _filter_patterns(
for pattern_desc, pattern_creator in patterns_to_filter.items():
pattern_desc_devices = pattern_desc.value.devices
pattern_desc_model_types = pattern_desc.value.model_types
pattern_ignored_algorithms = pattern_desc.value.ignored_algorithms
devices_condition = pattern_desc_devices is None or device in pattern_desc_devices
model_types_condition = pattern_desc_model_types is None or model_type in pattern_desc_model_types
ignored_algorithms_condition = (
pattern_ignored_algorithms is None or algorithm not in pattern_ignored_algorithms
)
if devices_condition and model_types_condition and ignored_algorithms_condition:
if devices_condition and model_types_condition:
filtered_patterns[pattern_desc] = pattern_creator
return filtered_patterns

Expand All @@ -107,29 +99,24 @@ def _get_full_pattern_graph(
backend_patterns_map: Dict[PatternNames, Callable[[], GraphPattern]],
device: TargetDevice,
model_type: ModelType,
algorithm: Optional[AlgorithmType],
) -> GraphPattern:
"""
Filters patterns and returns GraphPattern with registered filtered patterns.
:param backend_patterns_map: Dictionary with the PatternNames instance as keys and creator function as a value.
:param device: TargetDevice instance.
:param model_type: ModelType instance.
:param algorithm: AlgorithmType instance.
:return: Completed GraphPattern based on the backend, device & model_type.
"""
filtered_patterns = PatternsManager._filter_patterns(backend_patterns_map, device, model_type, algorithm)
filtered_patterns = PatternsManager._filter_patterns(backend_patterns_map, device, model_type)
patterns = Patterns()
for pattern_desc, pattern_creator in filtered_patterns.items():
patterns.register(pattern_creator(), pattern_desc.value.name)
return patterns.get_full_pattern_graph()

@staticmethod
def get_full_hw_pattern_graph(
backend: BackendType,
device: TargetDevice,
model_type: Optional[ModelType] = None,
algorithm: Optional[AlgorithmType] = None,
backend: BackendType, device: TargetDevice, model_type: Optional[ModelType] = None
) -> GraphPattern:
"""
Returns a GraphPattern containing all registered hardware patterns specifically
Expand All @@ -138,18 +125,14 @@ def get_full_hw_pattern_graph(
:param backend: BackendType instance.
:param device: TargetDevice instance.
:param model_type: ModelType instance.
:param algorithm: AlgorithmType instance.
:return: Completed GraphPattern based on the backend, device & model_type.
"""
backend_patterns_map = PatternsManager._get_backend_hw_patterns_map(backend)
return PatternsManager._get_full_pattern_graph(backend_patterns_map, device, model_type, algorithm)
return PatternsManager._get_full_pattern_graph(backend_patterns_map, device, model_type)

@staticmethod
def get_full_ignored_pattern_graph(
backend: BackendType,
device: TargetDevice,
model_type: Optional[ModelType] = None,
algorithm: Optional[AlgorithmType] = None,
backend: BackendType, device: TargetDevice, model_type: Optional[ModelType] = None
) -> GraphPattern:
"""
Returns a GraphPattern containing all registered ignored patterns specifically
Expand All @@ -158,8 +141,7 @@ def get_full_ignored_pattern_graph(
:param backend: BackendType instance.
:param device: TargetDevice instance.
:param model_type: ModelType instance.
:param algorithm: AlgorithmType instance.
:return: Completed GraphPattern with registered value based on the backend, device & model_type.
"""
backend_patterns_map = PatternsManager._get_backend_ignored_patterns_map(backend)
return PatternsManager._get_full_pattern_graph(backend_patterns_map, device, model_type, algorithm)
return PatternsManager._get_full_pattern_graph(backend_patterns_map, device, model_type)
18 changes: 5 additions & 13 deletions nncf/common/graph/patterns/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,6 @@ def merge_two_types_of_operations(first_op: Dict, second_op: Dict, label: str) -
raise RuntimeError("Incorrect dicts of operations")


class AlgorithmType(Enum):
"""
Algorithm type which is used by the pattern manager to
provide patterns specific to the target algorithm type.
"""

QUANTIZATION = "QUANTIZATION"
NAS = "NAS"


@dataclass
class PatternDesc:
"""
Expand All @@ -286,8 +276,7 @@ class PatternDesc:

name: str
devices: Optional[List[TargetDevice]] = None
model_types: Optional[List[ModelType]] = None
ignored_algorithms: Optional[List[AlgorithmType]] = None
    model_types: Optional[List[ModelType]] = None


class HWFusedPatternNames(Enum):
Expand All @@ -298,6 +287,8 @@ class HWFusedPatternNames(Enum):

# ATOMIC OPERATIONS
L2_NORM = PatternDesc("l2_norm")
MVN = PatternDesc("mvn")
GELU = PatternDesc("gelu")

# BLOCK PATTERNS
ADD_SCALE_SHIFT_OUTPUT = PatternDesc("add_scale_shift_output")
Expand All @@ -307,7 +298,6 @@ class HWFusedPatternNames(Enum):
NORMALIZE_L2_MULTIPLY = PatternDesc("normalize_l2_multiply")
SCALE_SHIFT = PatternDesc("scale_shift")
SHIFT_SCALE = PatternDesc("shift_scale")
SE_BLOCK = PatternDesc("se_block", ignored_algorithms=[AlgorithmType.NAS])
SOFTMAX_DIV = PatternDesc("softmax_div")

# ACTIVATIONS
Expand Down Expand Up @@ -349,6 +339,7 @@ class HWFusedPatternNames(Enum):
LINEAR_ACTIVATIONS_BATCH_NORM = PatternDesc("linear_activations_batch_norm")
LINEAR_ACTIVATIONS_SCALE_SHIFT = PatternDesc("linear_activations_scale_shift")
LINEAR_ARITHMETIC = PatternDesc("linear_arithmetic")
LINEAR_SHIFT_SCALE = PatternDesc("linear_shift_scale")
LINEAR_ARITHMETIC_ACTIVATIONS = PatternDesc("linear_arithmetic_activations")
# Found in PicoDet models
LINEAR_ARITHMETIC_ACTIVATIONS_ARITHMETIC = PatternDesc("linear_arithmetic_activations_arithmetic")
Expand Down Expand Up @@ -404,5 +395,6 @@ class IgnoredPatternNames(Enum):
model_types=[ModelType.TRANSFORMER],
devices=[TargetDevice.ANY, TargetDevice.CPU, TargetDevice.GPU, TargetDevice.VPU],
)
SE_BLOCK = PatternDesc("se_block")
FC_BN_HSWISH_ACTIVATION = PatternDesc("fc_bn_hswish_activation")
EQUAL_LOGICALNOT = PatternDesc("equal_logicalnot")
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from nncf.common.graph.graph_matching import find_subgraphs_matching_pattern
from nncf.common.graph.patterns.manager import PatternsManager
from nncf.common.graph.patterns.manager import TargetDevice
from nncf.common.graph.patterns.patterns import AlgorithmType
from nncf.common.utils.backend import BackendType
from nncf.common.utils.dot_file_rw import write_dot_graph
from nncf.torch.graph.graph import PTNNCFGraph
Expand Down Expand Up @@ -258,9 +257,7 @@ def get_merged_original_graph_with_pattern(orig_graph: nx.DiGraph, hw_fused_ops:
if not hw_fused_ops:
return merged_graph
# pylint: disable=protected-access
pattern_fusing_graph = PatternsManager.get_full_hw_pattern_graph(
backend=BackendType.TORCH, device=TargetDevice.ANY, algorithm=AlgorithmType.NAS
)
pattern_fusing_graph = PatternsManager.get_full_hw_pattern_graph(backend=BackendType.TORCH, device=TargetDevice.ANY)
matches = find_subgraphs_matching_pattern(orig_graph, pattern_fusing_graph)
nx.set_node_attributes(merged_graph, False, SearchGraph.IS_DUMMY_NODE_ATTR)
nx.set_node_attributes(merged_graph, False, SearchGraph.IS_MERGED_NODE_ATTR)
Expand Down
42 changes: 0 additions & 42 deletions nncf/openvino/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,48 +147,6 @@ def create_shift_scale() -> GraphPattern:
return pattern


@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SE_BLOCK)
def create_se_block() -> GraphPattern:
    """
    Builds the Squeeze-and-Excitation block pattern for the OpenVINO backend:
    reduce_mean -> linear -> add -> (relu|prelu|swish) -> linear -> add -> sigmoid -> multiply,
    where the block input (a non-pattern node) also feeds the closing multiply —
    the SE skip connection.

    :return: GraphPattern describing the SE block.
    """
    pattern = GraphPattern()

    def _add(label, metatype):
        # Thin wrapper to keep the node declarations below compact.
        return pattern.add_node(**{GraphPattern.LABEL_ATTR: label, GraphPattern.METATYPE_ATTR: metatype})

    input_node = _add("ANY", GraphPattern.NON_PATTERN_NODE_TYPE)
    chain = [
        input_node,
        _add("REDUCE_MEAN", om.OVReduceMeanMetatype),
        _add("LINEAR", LINEAR_OPERATIONS),
        _add("ADD_BIAS", om.OVAddMetatype),
        _add("RELU, PRELU, SWISH", [om.OVReluMetatype, om.OVPReluMetatype, om.OVSwishMetatype]),
        _add("LINEAR", LINEAR_OPERATIONS),
        _add("ADD_BIAS", om.OVAddMetatype),
        _add("SIGMOID", om.OVSigmoidMetatype),
        _add("MULTIPLY", om.OVMultiplyMetatype),
    ]
    # Wire the main chain in declaration order, then the skip edge into the multiply.
    for src, dst in zip(chain, chain[1:]):
        pattern.add_edge(src, dst)
    pattern.add_edge(input_node, chain[-1])
    return pattern


@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SOFTMAX_DIV)
def create_softmax_div() -> GraphPattern:
pattern = GraphPattern()
Expand Down
43 changes: 43 additions & 0 deletions nncf/openvino/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from nncf.common.graph.patterns.patterns import IgnoredPatternNames
from nncf.common.utils.registry import Registry
from nncf.openvino.graph.metatypes import openvino_metatypes as om
from nncf.openvino.graph.metatypes.groups import LINEAR_OPERATIONS

OPENVINO_IGNORED_PATTERNS = Registry("IGNORED_PATTERNS")

Expand Down Expand Up @@ -108,3 +109,45 @@ def create_equal_logicalnot() -> GraphPattern:

pattern.add_edge(equal_node, logical_not_node)
return pattern


@OPENVINO_IGNORED_PATTERNS.register(IgnoredPatternNames.SE_BLOCK)
def create_se_block() -> GraphPattern:
    """
    Builds the Squeeze-and-Excitation ignored pattern for the OpenVINO backend:
    reduce_mean -> linear -> add -> (relu|prelu|swish) -> linear -> add -> sigmoid -> multiply,
    where the block input (a non-pattern node) also feeds the closing multiply —
    the SE skip connection.

    :return: GraphPattern describing the SE block.
    """
    pattern = GraphPattern()

    def _add(label, metatype):
        # Thin wrapper to keep the node declarations below compact.
        return pattern.add_node(**{GraphPattern.LABEL_ATTR: label, GraphPattern.METATYPE_ATTR: metatype})

    input_node = _add("ANY", GraphPattern.NON_PATTERN_NODE_TYPE)
    chain = [
        input_node,
        _add("REDUCE_MEAN", om.OVReduceMeanMetatype),
        _add("LINEAR", LINEAR_OPERATIONS),
        _add("ADD_BIAS", om.OVAddMetatype),
        _add("RELU, PRELU, SWISH", [om.OVReluMetatype, om.OVPReluMetatype, om.OVSwishMetatype]),
        _add("LINEAR", LINEAR_OPERATIONS),
        _add("ADD_BIAS", om.OVAddMetatype),
        _add("SIGMOID", om.OVSigmoidMetatype),
        _add("MULTIPLY", om.OVMultiplyMetatype),
    ]
    # Wire the main chain in declaration order, then the skip edge into the multiply.
    for src, dst in zip(chain, chain[1:]):
        pattern.add_edge(src, dst)
    pattern.add_edge(input_node, chain[-1])
    return pattern
117 changes: 0 additions & 117 deletions nncf/torch/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,120 +320,3 @@ def create_h_sigmoid_act() -> GraphPattern:
main_pattern.add_pattern_alternative(pattern)

return main_pattern


# pylint:disable=too-many-statements
@PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SE_BLOCK)
def create_se_block() -> GraphPattern:
    """
    Builds the Squeeze-and-Excitation fused pattern for the Torch backend.

    Four alternatives are registered, in this order: the plain SE chain, the
    chain with explicit bias additions, and both of those with reshape ops
    around the excitation sub-network. In every alternative the block input
    (a non-pattern node) also feeds the closing multiply — the SE skip
    connection.

    :return: GraphPattern holding all four SE-block alternatives.
    """
    MEAN_OPERATIONS = {
        GraphPattern.LABEL_ATTR: "REDUCE_MEAN",
        GraphPattern.METATYPE_ATTR: ["avg_pool2d", "adaptive_avg_pool2d", "avg_pool3d", "adaptive_avg_pool3d", "mean"],
    }
    SIGMOID_OPERATIONS = {
        GraphPattern.LABEL_ATTR: "SIGMOID",
        GraphPattern.METATYPE_ATTR: ["sigmoid", "hardsigmoid"],
    }
    RESHAPE_NODES = {
        GraphPattern.LABEL_ATTR: "RESHAPE",
        GraphPattern.METATYPE_ATTR: ["reshape", "view", "flatten", "unsqueeze"],
    }

    def _build_variant(with_bias: bool, with_reshape: bool) -> GraphPattern:
        # Assembles one SE alternative as a linear chain of nodes plus the
        # skip edge from the input node into the closing multiply.
        pattern = GraphPattern()
        input_node = pattern.add_node(label="ANY", type=GraphPattern.NON_PATTERN_NODE_TYPE)
        chain = [input_node, pattern.add_node(**MEAN_OPERATIONS)]
        if with_reshape:
            chain.append(pattern.add_node(**RESHAPE_NODES))
        chain.append(pattern.add_node(**LINEAR_OPERATIONS))
        if with_bias:
            chain.append(pattern.add_node(label="ADD_BIAS", type=["__add__", "__sub__"]))
        chain.append(pattern.add_node(**ATOMIC_ACTIVATIONS_OPERATIONS))
        chain.append(pattern.add_node(**LINEAR_OPERATIONS))
        if with_bias:
            chain.append(pattern.add_node(label="ADD_BIAS", type=["__add__", "__sub__"]))
        chain.append(pattern.add_node(**SIGMOID_OPERATIONS))
        if with_reshape:
            chain.append(pattern.add_node(**RESHAPE_NODES))
        multiply_node = pattern.add_node(label="MUL", type="__mul__")
        chain.append(multiply_node)
        for src, dst in zip(chain, chain[1:]):
            pattern.add_edge(src, dst)
        pattern.add_edge(input_node, multiply_node)
        return pattern

    main_pattern = GraphPattern()
    # Registration order matches the original: plain, bias, reshape, bias+reshape.
    for with_bias, with_reshape in ((False, False), (True, False), (False, True), (True, True)):
        main_pattern.add_pattern_alternative(_build_variant(with_bias, with_reshape))
    return main_pattern
Loading

0 comments on commit d7b342e

Please sign in to comment.