Skip to content

Commit

Permalink
[Torch] Use metatypes instead of types in pattern matching (#2787)
Browse files Browse the repository at this point in the history
### Changes

* ~~GraphPattern type/metatype attrs are renamed~~
* `Metatypes` are used to identify the type of a node instead of `type`
* Elementwise metatypes are updated with in-place types (e.g. `add_`, `sub_`, `mul_`, `div_`)

### Reason for changes

As the number of types is going to grow (due to the new TorchFX backend), it is
easier to specify a metatype once instead of specifying types, which
would have to be updated each time a new type arrives.

### Related tickets

145981

### Tests

<!--- How was the correctness of changes tested and whether new tests
were added -->
  • Loading branch information
daniil-lyakhov authored Jul 23, 2024
1 parent f981d53 commit 0e1c83a
Show file tree
Hide file tree
Showing 9 changed files with 262 additions and 172 deletions.
35 changes: 27 additions & 8 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,13 @@ class PTSigmoidMetatype(PTOperatorMetatype):
class PTAddMetatype(PTOperatorMetatype):
name = "AddOp"
module_to_function_names = {
NamespaceTarget.TORCH_TENSOR: ["add", "__add__", "__iadd__", "__radd__"],
NamespaceTarget.TORCH_TENSOR: [
"add",
"add_",
"__add__",
"__iadd__",
"__radd__",
],
NamespaceTarget.TORCH: ["add"],
}
hw_config_names = [HWConfigOpName.ADD]
Expand All @@ -556,7 +562,13 @@ class PTAddMetatype(PTOperatorMetatype):
class PTSubMetatype(PTOperatorMetatype):
name = "SubOp"
module_to_function_names = {
NamespaceTarget.TORCH_TENSOR: ["sub", "__sub__", "__isub__", "__rsub__"],
NamespaceTarget.TORCH_TENSOR: [
"sub",
"sub_",
"__sub__",
"__isub__",
"__rsub__",
],
NamespaceTarget.TORCH: ["sub"],
}
hw_config_names = [HWConfigOpName.SUBTRACT]
Expand All @@ -567,7 +579,7 @@ class PTSubMetatype(PTOperatorMetatype):
class PTMulMetatype(PTOperatorMetatype):
name = "MulOp"
module_to_function_names = {
NamespaceTarget.TORCH_TENSOR: ["mul", "__mul__", "__imul__", "__rmul__"],
NamespaceTarget.TORCH_TENSOR: ["mul", "mul_", "__mul__", "__imul__", "__rmul__"],
NamespaceTarget.TORCH: ["mul"],
}
hw_config_names = [HWConfigOpName.MULTIPLY]
Expand All @@ -580,6 +592,7 @@ class PTDivMetatype(PTOperatorMetatype):
module_to_function_names = {
NamespaceTarget.TORCH_TENSOR: [
"div",
"div_",
"__div__",
"__idiv__",
"__rdiv__",
Expand Down Expand Up @@ -691,13 +704,17 @@ class PTThresholdMetatype(PTOperatorMetatype):
@PT_OPERATOR_METATYPES.register(is_subtype=True)
class PTModuleBatchNormMetatype(PTModuleOperatorSubtype):
name = "BatchNormOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]}
module_to_function_names = {
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"],
}


@PT_OPERATOR_METATYPES.register()
class PTBatchNormMetatype(PTOperatorMetatype):
name = "BatchNormOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]}
module_to_function_names = {
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"],
}
subtypes = [PTModuleBatchNormMetatype]
weight_port_ids = [3]
bias_port_id = 4
Expand Down Expand Up @@ -826,7 +843,7 @@ class PTGatherMetatype(PTOperatorMetatype):
name = "GatherOp"
module_to_function_names = {
NamespaceTarget.TORCH_TENSOR: ["index_select", "__getitem__"],
NamespaceTarget.TORCH: ["gather", "index_select", "where"],
NamespaceTarget.TORCH: ["gather", "index_select", "select", "where"],
}


Expand All @@ -841,7 +858,7 @@ class PTReshapeMetatype(PTOperatorMetatype):
name = "ReshapeOp"
module_to_function_names = {
NamespaceTarget.TORCH_TENSOR: ["reshape", "view", "flatten", "unsqueeze"],
NamespaceTarget.TORCH: ["flatten", "unsqueeze"],
NamespaceTarget.TORCH: ["flatten", "unflatten", "unsqueeze"],
}
hw_config_names = [HWConfigOpName.RESHAPE, HWConfigOpName.UNSQUEEZE, HWConfigOpName.FLATTEN]

Expand Down Expand Up @@ -1028,7 +1045,9 @@ class PTSqrtMetatype(PTOperatorMetatype):
@PT_OPERATOR_METATYPES.register()
class PTInterpolateMetatype(PTOperatorMetatype):
name = "InterpolateOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["interpolate"]}
module_to_function_names = {
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["interpolate"],
}
hw_config_names = [HWConfigOpName.INTERPOLATE]
num_expected_input_edges = 1

Expand Down
86 changes: 52 additions & 34 deletions nncf/torch/graph/pattern_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,54 +10,79 @@
# limitations under the License.
from nncf.common.graph.patterns import GraphPattern
from nncf.common.graph.patterns import merge_two_types_of_operations
from nncf.torch.graph import operator_metatypes as om

LINEAR_OPERATIONS = {
GraphPattern.METATYPE_ATTR: [
"linear",
"conv1d",
"conv2d",
"conv3d",
"conv_transpose1d",
"conv_transpose2d",
"conv_transpose3d",
"deform_conv2d",
"addmm",
"bmm",
"matmul",
"mm",
"baddbmm",
# Linear
om.PTLinearMetatype,
om.PTModuleLinearMetatype,
# Conv1D
om.PTConv1dMetatype,
om.PTDepthwiseConv1dSubtype,
om.PTModuleConv1dMetatype,
om.PTModuleDepthwiseConv1dSubtype,
# Conv2D
om.PTConv2dMetatype,
om.PTDepthwiseConv2dSubtype,
om.PTModuleConv2dMetatype,
om.PTModuleDepthwiseConv2dSubtype,
# Conv3D
om.PTConv3dMetatype,
om.PTDepthwiseConv3dSubtype,
om.PTModuleConv3dMetatype,
om.PTModuleDepthwiseConv3dSubtype,
# Transposed conv
om.PTConvTranspose1dMetatype,
om.PTModuleConvTranspose1dMetatype,
om.PTConvTranspose2dMetatype,
om.PTModuleConvTranspose2dMetatype,
om.PTConvTranspose3dMetatype,
om.PTModuleConvTranspose3dMetatype,
# Deform conv
om.PTDeformConv2dMetatype,
om.PTModuleDeformConv2dMetatype,
# MatMul
om.PTMatMulMetatype,
# Addmm
om.PTAddmmMetatype,
],
GraphPattern.LABEL_ATTR: "LINEAR",
}

BATCH_NORMALIZATION_OPERATIONS = {
GraphPattern.METATYPE_ATTR: ["batch_norm", "batch_norm1d", "batch_norm2d", "batch_norm3d"],
GraphPattern.METATYPE_ATTR: [om.PTBatchNormMetatype, om.PTModuleBatchNormMetatype],
GraphPattern.LABEL_ATTR: "BATCH_NORMALIZATION",
}

GROUP_NORMALIZATION_OPERATIONS = {
GraphPattern.METATYPE_ATTR: ["group_norm"],
GraphPattern.METATYPE_ATTR: [om.PTGroupNormMetatype, om.PTModuleGroupNormMetatype],
GraphPattern.LABEL_ATTR: "GROUP_NORMALIZATION",
}

LAYER_NORMALIZATION_OPERATIONS = {
GraphPattern.METATYPE_ATTR: ["layer_norm"],
GraphPattern.METATYPE_ATTR: [om.PTLayerNormMetatype, om.PTModuleLayerNormMetatype],
GraphPattern.LABEL_ATTR: "LAYER_NORMALIZATION",
}

RELU_OPERATIONS = {GraphPattern.METATYPE_ATTR: ["relu", "relu_", "hardtanh"], GraphPattern.LABEL_ATTR: "RELU"}
RELU_OPERATIONS = {
GraphPattern.METATYPE_ATTR: [
om.PTRELUMetatype,
om.PTHardTanhMetatype,
],
GraphPattern.LABEL_ATTR: "RELU",
}

NON_RELU_ACTIVATIONS_OPERATIONS = {
GraphPattern.METATYPE_ATTR: [
"elu",
"elu_",
"prelu",
"leaky_relu",
"sigmoid",
"gelu",
"silu",
"hardsigmoid",
"hardswish",
om.PTELUMetatype,
om.PTPRELUMetatype,
om.PTLeakyRELUMetatype,
om.PTSigmoidMetatype,
om.PTGELUMetatype,
om.PTSILUMetatype,
om.PTHardSigmoidMetatype,
om.PTHardSwishMetatype,
],
GraphPattern.LABEL_ATTR: "NON_RELU_ACTIVATIONS",
}
Expand All @@ -67,13 +92,6 @@
)

ARITHMETIC_OPERATIONS = {
GraphPattern.METATYPE_ATTR: ["__iadd__", "__add__", "__mul__", "__rmul__", "__truediv__"],
GraphPattern.METATYPE_ATTR: [om.PTAddMetatype, om.PTSubMetatype, om.PTMulMetatype, om.PTDivMetatype],
GraphPattern.LABEL_ATTR: "ARITHMETIC",
}

# This type may be useful in the future

POOLING_OPERATIONS = {
GraphPattern.METATYPE_ATTR: ["adaptive_avg_pool2d", "adaptive_avg_pool3d", "avg_pool2d", "avg_pool3d"],
GraphPattern.LABEL_ATTR: "POOLING",
}
71 changes: 36 additions & 35 deletions nncf/torch/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from nncf.common.graph.patterns import GraphPattern
from nncf.common.graph.patterns import HWFusedPatternNames
from nncf.common.utils.registry import Registry
from nncf.torch.graph import operator_metatypes as om
from nncf.torch.graph.operator_metatypes import PTInputNoopMetatype
from nncf.torch.graph.pattern_operations import ARITHMETIC_OPERATIONS
from nncf.torch.graph.pattern_operations import ATOMIC_ACTIVATIONS_OPERATIONS
Expand All @@ -30,12 +31,12 @@ def create_l2_norm_operations() -> GraphPattern:
pattern = GraphPattern()

outside_pattern_node = pattern.add_node(label="*OUTSIDE_PATTERN_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE)
pow_node = pattern.add_node(label="POW", type="pow")
sum_node = pattern.add_node(label="SUM", type="sum")
sqrt_node = pattern.add_node(label="SQRT", type="sqrt")
add_node = pattern.add_node(label="ADD", type="__add__")
div_node = pattern.add_node(label="DIV", type="div")
mul_node = pattern.add_node(label="MUL", type="__rmul__")
pow_node = pattern.add_node(label="POW", type=om.PTPowerMetatype)
sum_node = pattern.add_node(label="SUM", type=om.PTSumMetatype)
sqrt_node = pattern.add_node(label="SQRT", type=om.PTSqrtMetatype)
add_node = pattern.add_node(label="ADD", type=om.PTAddMetatype)
div_node = pattern.add_node(label="DIV", type=om.PTDivMetatype)
mul_node = pattern.add_node(label="MUL", type=om.PTMulMetatype)

pattern.add_edge(outside_pattern_node, pow_node)
pattern.add_edge(pow_node, sum_node)
Expand All @@ -53,16 +54,16 @@ def create_l2_norm_operations() -> GraphPattern:
@PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SHIFT_SCALE)
def create_shift_scale() -> GraphPattern:
pattern = GraphPattern()
add_node = pattern.add_node(label="ADD, SUB", type=["__add__", "__sub__"])
truediv_node = pattern.add_node(label="MUL, DIV", type=["__mul__", "__truediv__"])
add_node = pattern.add_node(label="ADD, SUB", type=[om.PTAddMetatype, om.PTSubMetatype])
truediv_node = pattern.add_node(label="MUL, DIV", type=[om.PTMulMetatype, om.PTDivMetatype])
pattern.add_edge(add_node, truediv_node)
return pattern


@PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_SHIFT_SCALE)
def create_input_shift_scale() -> GraphPattern:
pattern = GraphPattern()
pattern.add_node(**{GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: PTInputNoopMetatype})
pattern.add_node(label="MODEL_INPUT", type=PTInputNoopMetatype)
shift_scale = create_shift_scale()
pattern.join_patterns(shift_scale)
return pattern
Expand Down Expand Up @@ -177,8 +178,8 @@ def create_group_norm_relu_operations() -> GraphPattern:
@PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.LINEAR_CONST_MULTIPLY)
def create_linear_const_multiply() -> GraphPattern:
pattern = GraphPattern()
linear_node = pattern.add_node(label="linear", type="linear")
mul_node = pattern.add_node(label="MUL", type="__mul__")
linear_node = pattern.add_node(label="linear", type=[om.PTLinearMetatype, om.PTModuleLinearMetatype])
mul_node = pattern.add_node(label="MUL", type=om.PTMulMetatype)
pattern.add_edge(linear_node, mul_node)

return pattern
Expand Down Expand Up @@ -220,8 +221,8 @@ def activation_operations() -> GraphPattern:
def create_swish_act() -> GraphPattern:
pattern = GraphPattern()
input_pattern_node = pattern.add_node(label="*INPUT_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE)
sigmoid_node = pattern.add_node(label="SIGMOID", type="sigmoid")
mul_node = pattern.add_node(label="MUL", type="__mul__")
sigmoid_node = pattern.add_node(label="SIGMOID", type=om.PTSigmoidMetatype)
mul_node = pattern.add_node(label="MUL", type=om.PTMulMetatype)

pattern.add_edge(input_pattern_node, sigmoid_node)
pattern.add_edge(sigmoid_node, mul_node)
Expand All @@ -235,10 +236,10 @@ def create_h_swish_act() -> GraphPattern:
# Mul -> Div version
pattern = GraphPattern()
input_pattern_node = pattern.add_node(label="*INPUT_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE)
add_node = pattern.add_node(label="ADD", type="__add__")
hardtanh_node = pattern.add_node(label="HARDTANH", type="hardtanh")
truediv_node = pattern.add_node(label="DIV", type="__truediv__")
mul_node = pattern.add_node(label="MUL", type="__mul__")
add_node = pattern.add_node(label="ADD", type=om.PTAddMetatype)
hardtanh_node = pattern.add_node(label="HARDTANH", type=om.PTHardTanhMetatype)
truediv_node = pattern.add_node(label="DIV", type=om.PTDivMetatype)
mul_node = pattern.add_node(label="MUL", type=om.PTMulMetatype)

pattern.add_edge(input_pattern_node, add_node)
pattern.add_edge(input_pattern_node, mul_node)
Expand All @@ -250,10 +251,10 @@ def create_h_swish_act() -> GraphPattern:
# Div -> Mul version
pattern = GraphPattern()
input_pattern_node = pattern.add_node(label="*INPUT_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE)
add_node = pattern.add_node(label="ADD", type="__add__")
hardtanh_node = pattern.add_node(label="HARDTANH", type="hardtanh")
mul_node = pattern.add_node(label="MUL", type="__mul__")
truediv_node = pattern.add_node(label="DIV", type="__truediv__")
add_node = pattern.add_node(label="ADD", type=om.PTAddMetatype)
hardtanh_node = pattern.add_node(label="HARDTANH", type=om.PTHardTanhMetatype)
mul_node = pattern.add_node(label="MUL", type=om.PTMulMetatype)
truediv_node = pattern.add_node(label="DIV", type=om.PTDivMetatype)

pattern.add_edge(input_pattern_node, add_node)
pattern.add_edge(input_pattern_node, mul_node)
Expand All @@ -265,10 +266,10 @@ def create_h_swish_act() -> GraphPattern:
# ReLU6 version - Mul -> Div
pattern = GraphPattern()
input_pattern_node = pattern.add_node(label="*INPUT_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE)
add_node = pattern.add_node(label="ADD", type="__add__")
relu6_node = pattern.add_node(label="RELU6", type="relu6")
mul_node = pattern.add_node(label="MUL", type="__mul__")
truediv_node = pattern.add_node(label="DIV", type="__truediv__")
add_node = pattern.add_node(label="ADD", type=om.PTAddMetatype)
relu6_node = pattern.add_node(label="RELU6", type=om.PTRELU6Metatype)
mul_node = pattern.add_node(label="MUL", type=om.PTMulMetatype)
truediv_node = pattern.add_node(label="DIV", type=om.PTDivMetatype)

pattern.add_edge(input_pattern_node, add_node)
pattern.add_edge(input_pattern_node, mul_node)
Expand All @@ -280,10 +281,10 @@ def create_h_swish_act() -> GraphPattern:
# ReLU6 version - Div -> Mul
pattern = GraphPattern()
input_pattern_node = pattern.add_node(label="*INPUT_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE)
add_node = pattern.add_node(label="ADD", type="__add__")
relu6_node = pattern.add_node(label="RELU6", type="relu6")
truediv_node = pattern.add_node(label="DIV", type="__truediv__")
mul_node = pattern.add_node(label="MUL", type="__mul__")
add_node = pattern.add_node(label="ADD", type=om.PTAddMetatype)
relu6_node = pattern.add_node(label="RELU6", type=om.PTRELU6Metatype)
truediv_node = pattern.add_node(label="DIV", type=om.PTDivMetatype)
mul_node = pattern.add_node(label="MUL", type=om.PTMulMetatype)

pattern.add_edge(input_pattern_node, add_node)
pattern.add_edge(input_pattern_node, mul_node)
Expand All @@ -303,9 +304,9 @@ def create_h_sigmoid_act() -> GraphPattern:
pattern = GraphPattern()

input_pattern_node = pattern.add_node(label="*INPUT_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE)
add_node = pattern.add_node(label="ADD", type="__add__")
hardtanh_node = pattern.add_node(label="HARDTANH", type="hardtanh")
truediv_node = pattern.add_node(label="DIV", type="__truediv__")
add_node = pattern.add_node(label="ADD", type=om.PTAddMetatype)
hardtanh_node = pattern.add_node(label="HARDTANH", type=om.PTHardTanhMetatype)
truediv_node = pattern.add_node(label="DIV", type=om.PTDivMetatype)

pattern.add_edge(input_pattern_node, add_node)
pattern.add_edge(add_node, hardtanh_node)
Expand All @@ -317,9 +318,9 @@ def create_h_sigmoid_act() -> GraphPattern:
pattern = GraphPattern()

input_pattern_node = pattern.add_node(label="*INPUT_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE)
add_node = pattern.add_node(label="ADD", type="__add__")
relu6_node = pattern.add_node(label="RELU6", type="relu6")
truediv_node = pattern.add_node(label="DIV", type="__truediv__")
add_node = pattern.add_node(label="ADD", type=om.PTAddMetatype)
relu6_node = pattern.add_node(label="RELU6", type=om.PTRELU6Metatype)
truediv_node = pattern.add_node(label="DIV", type=om.PTDivMetatype)

pattern.add_edge(input_pattern_node, add_node)
pattern.add_edge(add_node, relu6_node)
Expand Down
Loading

0 comments on commit 0e1c83a

Please sign in to comment.