Skip to content

Commit

Permalink
add linear_squeeze_arithmetical_activations pattern and ONNXGlobalMax…
Browse files Browse the repository at this point in the history
…PoolMetatype metatype for ONNX backend
  • Loading branch information
alexsu52 committed Sep 18, 2023
1 parent 2b14d25 commit ccded55
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 0 deletions.
1 change: 1 addition & 0 deletions nncf/common/graph/patterns/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ class HWFusedPatternNames(Enum):
LINEAR_SCALE_SHIFT_ACTIVATIONS = PatternDesc("linear_scale_shift_activations")
LINEAR_CONST_MULTIPLY = PatternDesc("linear_const_multiply")
LINEAR_SQUEEZE_ACTIVATIONS = PatternDesc("linear_squeeze_activations")
LINEAR_SQUEEZE_ARITHMETIC_ACTIVATIONS = PatternDesc("linear_squeeze_arithmetic_activations")
LINEAR_ACTIVATIONS_UNSQUEEZE_BN_SQUEEZE = PatternDesc("linear_activations_unsqueeze_bn_squeeze")
SCALE_SHIFT_ACTIVATIONS = PatternDesc("scale_shift_activations")
MVN_SCALE_SHIFT_ACTIVATIONS = PatternDesc("mvn_scale_shift_activations")
Expand Down
1 change: 1 addition & 0 deletions nncf/onnx/graph/metatypes/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from nncf.onnx.graph.metatypes import onnx_metatypes

QUANTIZE_AGNOSTIC_OPERATIONS = [
onnx_metatypes.ONNXGlobalMaxPoolMetatype,
onnx_metatypes.ONNXMaxPoolMetatype,
onnx_metatypes.ONNXReduceMaxMetatype,
onnx_metatypes.ONNXReshapeMetatype,
Expand Down
7 changes: 7 additions & 0 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,13 @@ class ONNXAveragePoolMetatype(ONNXOpMetatype):
hw_config_names = [HWConfigOpName.AVGPOOL]


@ONNX_OPERATION_METATYPES.register()
class ONNXGlobalMaxPoolMetatype(ONNXOpMetatype):
name = "GlobalMaxPoolOp"
op_names = ["GlobalMaxPool"]
hw_config_names = [HWConfigOpName.MAXPOOL]


@ONNX_OPERATION_METATYPES.register()
class ONNXMaxPoolMetatype(ONNXOpMetatype):
name = "MaxPoolOp"
Expand Down
11 changes: 11 additions & 0 deletions nncf/onnx/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,17 @@ def create_linear_squeeze_activation() -> GraphPattern:
return linear


@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.LINEAR_SQUEEZE_ARITHMETIC_ACTIVATIONS)
def create_linear_squeeze_arithmetic_activation() -> GraphPattern:
linear = linear_operations()
squeeze = squeeze_operation()
arithmetic_activations = create_arithmetic_activations()

linear.join_patterns(squeeze)
linear.join_patterns(arithmetic_activations)
return linear


@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.BATCH_NORM_SCALE_SHIFT_ACTIVATIONS)
def create_bn_scale_shift_activation() -> GraphPattern:
batch_norm = batch_normalization_operations()
Expand Down
11 changes: 11 additions & 0 deletions nncf/openvino/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,17 @@ def create_linear_squeeze_activation() -> GraphPattern:
return linear


@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.LINEAR_SQUEEZE_ARITHMETIC_ACTIVATIONS)
def create_linear_squeeze_activation() -> GraphPattern:
linear = linear_operations()
squeeze = squeeze_operation()
arithmetic_activations = create_arithmetic_activations()

linear.join_patterns(squeeze)
linear.join_patterns(arithmetic_activations)
return linear


@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.MVN_SCALE_SHIFT_ACTIVATIONS)
def create_mvn_scale_shift_activations() -> GraphPattern:
pattern = GraphPattern()
Expand Down
1 change: 1 addition & 0 deletions tests/torch/test_pattern_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
HWFusedPatternNames.LINEAR_BIASED_ACTIVATION_ELEMENTWISE: "Not relevant for Torch.",
HWFusedPatternNames.MVN_SCALE_SHIFT_ACTIVATIONS: "Not relevant for Torch.",
HWFusedPatternNames.LINEAR_SQUEEZE_ACTIVATIONS: "Not relevant for Torch.",
HWFusedPatternNames.LINEAR_SQUEEZE_ARITHMETIC_ACTIVATIONS: "Not relevant for Torch.",
HWFusedPatternNames.LINEAR_ACTIVATIONS_UNSQUEEZE_BN_SQUEEZE: "Not relevant for Torch.",
}

Expand Down

0 comments on commit ccded55

Please sign in to comment.