diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py index 97835f2162a..ca038b77f24 100644 --- a/nncf/torch/graph/operator_metatypes.py +++ b/nncf/torch/graph/operator_metatypes.py @@ -545,7 +545,13 @@ class PTSigmoidMetatype(PTOperatorMetatype): class PTAddMetatype(PTOperatorMetatype): name = "AddOp" module_to_function_names = { - NamespaceTarget.TORCH_TENSOR: ["add", "__add__", "__iadd__", "__radd__"], + NamespaceTarget.TORCH_TENSOR: [ + "add", + "add_", + "__add__", + "__iadd__", + "__radd__", + ], NamespaceTarget.TORCH: ["add"], } hw_config_names = [HWConfigOpName.ADD] @@ -556,7 +562,13 @@ class PTAddMetatype(PTOperatorMetatype): class PTSubMetatype(PTOperatorMetatype): name = "SubOp" module_to_function_names = { - NamespaceTarget.TORCH_TENSOR: ["sub", "__sub__", "__isub__", "__rsub__"], + NamespaceTarget.TORCH_TENSOR: [ + "sub", + "sub_", + "__sub__", + "__isub__", + "__rsub__", + ], NamespaceTarget.TORCH: ["sub"], } hw_config_names = [HWConfigOpName.SUBTRACT] @@ -567,7 +579,7 @@ class PTSubMetatype(PTOperatorMetatype): class PTMulMetatype(PTOperatorMetatype): name = "MulOp" module_to_function_names = { - NamespaceTarget.TORCH_TENSOR: ["mul", "__mul__", "__imul__", "__rmul__"], + NamespaceTarget.TORCH_TENSOR: ["mul", "mul_", "__mul__", "__imul__", "__rmul__"], NamespaceTarget.TORCH: ["mul"], } hw_config_names = [HWConfigOpName.MULTIPLY] @@ -580,6 +592,7 @@ class PTDivMetatype(PTOperatorMetatype): module_to_function_names = { NamespaceTarget.TORCH_TENSOR: [ "div", + "div_", "__div__", "__idiv__", "__rdiv__", @@ -691,13 +704,17 @@ class PTThresholdMetatype(PTOperatorMetatype): @PT_OPERATOR_METATYPES.register(is_subtype=True) class PTModuleBatchNormMetatype(PTModuleOperatorSubtype): name = "BatchNormOp" - module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]} + module_to_function_names = { + NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"], + } @PT_OPERATOR_METATYPES.register() class PTBatchNormMetatype(PTOperatorMetatype): name = "BatchNormOp" - module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]} + module_to_function_names = { + NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"], + } subtypes = [PTModuleBatchNormMetatype] weight_port_ids = [3] bias_port_id = 4 @@ -826,7 +843,7 @@ class PTGatherMetatype(PTOperatorMetatype): name = "GatherOp" module_to_function_names = { NamespaceTarget.TORCH_TENSOR: ["index_select", "__getitem__"], - NamespaceTarget.TORCH: ["gather", "index_select", "where"], + NamespaceTarget.TORCH: ["gather", "index_select", "select", "where"], } @@ -841,7 +858,7 @@ class PTReshapeMetatype(PTOperatorMetatype): name = "ReshapeOp" module_to_function_names = { NamespaceTarget.TORCH_TENSOR: ["reshape", "view", "flatten", "unsqueeze"], - NamespaceTarget.TORCH: ["flatten", "unsqueeze"], + NamespaceTarget.TORCH: ["flatten", "unflatten", "unsqueeze"], } hw_config_names = [HWConfigOpName.RESHAPE, HWConfigOpName.UNSQUEEZE, HWConfigOpName.FLATTEN] @@ -1028,7 +1045,9 @@ class PTSqrtMetatype(PTOperatorMetatype): @PT_OPERATOR_METATYPES.register() class PTInterpolateMetatype(PTOperatorMetatype): name = "InterpolateOp" - module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["interpolate"]} + module_to_function_names = { + NamespaceTarget.TORCH_NN_FUNCTIONAL: ["interpolate"], + } hw_config_names = [HWConfigOpName.INTERPOLATE] num_expected_input_edges = 1 diff --git a/nncf/torch/graph/pattern_operations.py b/nncf/torch/graph/pattern_operations.py index d9957871d87..1190079a5a6 100644 --- a/nncf/torch/graph/pattern_operations.py +++ b/nncf/torch/graph/pattern_operations.py @@ -10,54 +10,79 @@ # limitations under the License. from nncf.common.graph.patterns import GraphPattern from nncf.common.graph.patterns import merge_two_types_of_operations +from nncf.torch.graph import operator_metatypes as om LINEAR_OPERATIONS = { GraphPattern.METATYPE_ATTR: [ - "linear", - "conv1d", - "conv2d", - "conv3d", - "conv_transpose1d", - "conv_transpose2d", - "conv_transpose3d", - "deform_conv2d", - "addmm", - "bmm", - "matmul", - "mm", - "baddbmm", + # Linear + om.PTLinearMetatype, + om.PTModuleLinearMetatype, + # Conv1D + om.PTConv1dMetatype, + om.PTDepthwiseConv1dSubtype, + om.PTModuleConv1dMetatype, + om.PTModuleDepthwiseConv1dSubtype, + # Conv2D + om.PTConv2dMetatype, + om.PTDepthwiseConv2dSubtype, + om.PTModuleConv2dMetatype, + om.PTModuleDepthwiseConv2dSubtype, + # Conv3D + om.PTConv3dMetatype, + om.PTDepthwiseConv3dSubtype, + om.PTModuleConv3dMetatype, + om.PTModuleDepthwiseConv3dSubtype, + # Transposed conv + om.PTConvTranspose1dMetatype, + om.PTModuleConvTranspose1dMetatype, + om.PTConvTranspose2dMetatype, + om.PTModuleConvTranspose2dMetatype, + om.PTConvTranspose3dMetatype, + om.PTModuleConvTranspose3dMetatype, + # Deform conv + om.PTDeformConv2dMetatype, + om.PTModuleDeformConv2dMetatype, + # MatMul + om.PTMatMulMetatype, + # Addmm + om.PTAddmmMetatype, ], GraphPattern.LABEL_ATTR: "LINEAR", } BATCH_NORMALIZATION_OPERATIONS = { - GraphPattern.METATYPE_ATTR: ["batch_norm", "batch_norm1d", "batch_norm2d", "batch_norm3d"], + GraphPattern.METATYPE_ATTR: [om.PTBatchNormMetatype, om.PTModuleBatchNormMetatype], GraphPattern.LABEL_ATTR: "BATCH_NORMALIZATION", } GROUP_NORMALIZATION_OPERATIONS = { - GraphPattern.METATYPE_ATTR: ["group_norm"], + GraphPattern.METATYPE_ATTR: [om.PTGroupNormMetatype, om.PTModuleGroupNormMetatype], GraphPattern.LABEL_ATTR: "GROUP_NORMALIZATION", } LAYER_NORMALIZATION_OPERATIONS = { - GraphPattern.METATYPE_ATTR: ["layer_norm"], + GraphPattern.METATYPE_ATTR: [om.PTLayerNormMetatype, om.PTModuleLayerNormMetatype], GraphPattern.LABEL_ATTR: "LAYER_NORMALIZATION", } -RELU_OPERATIONS = {GraphPattern.METATYPE_ATTR: ["relu", "relu_", "hardtanh"], GraphPattern.LABEL_ATTR: "RELU"} +RELU_OPERATIONS = { + GraphPattern.METATYPE_ATTR: [ + om.PTRELUMetatype, + om.PTHardTanhMetatype, + ], + GraphPattern.LABEL_ATTR: "RELU", +} NON_RELU_ACTIVATIONS_OPERATIONS = { GraphPattern.METATYPE_ATTR: [ - "elu", - "elu_", - "prelu", - "leaky_relu", - "sigmoid", - "gelu", - "silu", - "hardsigmoid", - "hardswish", + om.PTELUMetatype, + om.PTPRELUMetatype, + om.PTLeakyRELUMetatype, + om.PTSigmoidMetatype, + om.PTGELUMetatype, + om.PTSILUMetatype, + om.PTHardSigmoidMetatype, + om.PTHardSwishMetatype, ], GraphPattern.LABEL_ATTR: "NON_RELU_ACTIVATIONS", } @@ -67,13 +92,6 @@ ) ARITHMETIC_OPERATIONS = { - GraphPattern.METATYPE_ATTR: ["__iadd__", "__add__", "__mul__", "__rmul__", "__truediv__"], + GraphPattern.METATYPE_ATTR: [om.PTAddMetatype, om.PTSubMetatype, om.PTMulMetatype, om.PTDivMetatype], GraphPattern.LABEL_ATTR: "ARITHMETIC", } - -# This type may be useful in the future - -POOLING_OPERATIONS = { - GraphPattern.METATYPE_ATTR: ["adaptive_avg_pool2d", "adaptive_avg_pool3d", "avg_pool2d", "avg_pool3d"], - GraphPattern.LABEL_ATTR: "POOLING", -} diff --git a/nncf/torch/hardware/fused_patterns.py b/nncf/torch/hardware/fused_patterns.py index 3cff4d1ce98..48c6a38ba40 100644 --- a/nncf/torch/hardware/fused_patterns.py +++ b/nncf/torch/hardware/fused_patterns.py @@ -12,6 +12,7 @@ from nncf.common.graph.patterns import GraphPattern from nncf.common.graph.patterns import HWFusedPatternNames from nncf.common.utils.registry import Registry +from nncf.torch.graph import operator_metatypes as om from nncf.torch.graph.operator_metatypes import PTInputNoopMetatype from nncf.torch.graph.pattern_operations import ARITHMETIC_OPERATIONS from nncf.torch.graph.pattern_operations import ATOMIC_ACTIVATIONS_OPERATIONS @@ -30,12 +31,12 @@ def create_l2_norm_operations() -> GraphPattern: pattern = GraphPattern() outside_pattern_node = pattern.add_node(label="*OUTSIDE_PATTERN_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE) - pow_node = pattern.add_node(label="POW", type="pow") - sum_node = pattern.add_node(label="SUM", type="sum") - sqrt_node = pattern.add_node(label="SQRT", type="sqrt") - add_node = pattern.add_node(label="ADD", type="__add__") - div_node = pattern.add_node(label="DIV", type="div") - mul_node = pattern.add_node(label="MUL", type="__rmul__") + pow_node = pattern.add_node(label="POW", type=om.PTPowerMetatype) + sum_node = pattern.add_node(label="SUM", type=om.PTSumMetatype) + sqrt_node = pattern.add_node(label="SQRT", type=om.PTSqrtMetatype) + add_node = pattern.add_node(label="ADD", type=om.PTAddMetatype) + div_node = pattern.add_node(label="DIV", type=om.PTDivMetatype) + mul_node = pattern.add_node(label="MUL", type=om.PTMulMetatype) pattern.add_edge(outside_pattern_node, pow_node) pattern.add_edge(pow_node, sum_node) @@ -53,8 +54,8 @@ def create_l2_norm_operations() -> GraphPattern: @PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SHIFT_SCALE) def create_shift_scale() -> GraphPattern: pattern = GraphPattern() - add_node = pattern.add_node(label="ADD, SUB", type=["__add__", "__sub__"]) - truediv_node = pattern.add_node(label="MUL, DIV", type=["__mul__", "__truediv__"]) + add_node = pattern.add_node(label="ADD, SUB", type=[om.PTAddMetatype, om.PTSubMetatype]) + truediv_node = pattern.add_node(label="MUL, DIV", type=[om.PTMulMetatype, om.PTDivMetatype]) pattern.add_edge(add_node, truediv_node) return pattern @@ -62,7 +63,7 @@ def create_shift_scale() -> GraphPattern: @PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.INPUT_SHIFT_SCALE) def create_input_shift_scale() -> GraphPattern: pattern = GraphPattern() - pattern.add_node(**{GraphPattern.LABEL_ATTR: "MODEL_INPUT", GraphPattern.METATYPE_ATTR: PTInputNoopMetatype}) + pattern.add_node(label="MODEL_INPUT", type=PTInputNoopMetatype) shift_scale = create_shift_scale() pattern.join_patterns(shift_scale) return pattern @@ -177,8 +178,8 @@ def create_group_norm_relu_operations() -> GraphPattern: @PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.LINEAR_CONST_MULTIPLY) def create_linear_const_multiply() -> GraphPattern: pattern = GraphPattern() - linear_node = pattern.add_node(label="linear", type="linear") - mul_node = pattern.add_node(label="MUL", type="__mul__") + linear_node = pattern.add_node(label="linear", type=[om.PTLinearMetatype, om.PTModuleLinearMetatype]) + mul_node = pattern.add_node(label="MUL", type=om.PTMulMetatype) pattern.add_edge(linear_node, mul_node) return pattern @@ -220,8 +221,8 @@ def activation_operations() -> GraphPattern: def create_swish_act() -> GraphPattern: pattern = GraphPattern() input_pattern_node = pattern.add_node(label="*INPUT_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE) - sigmoid_node = pattern.add_node(label="SIGMOID", type="sigmoid") - mul_node = pattern.add_node(label="MUL", type="__mul__") + sigmoid_node = pattern.add_node(label="SIGMOID", type=om.PTSigmoidMetatype) + mul_node = pattern.add_node(label="MUL", type=om.PTMulMetatype) pattern.add_edge(input_pattern_node, sigmoid_node) pattern.add_edge(sigmoid_node, mul_node) @@ -235,10 +236,10 @@ def create_h_swish_act() -> GraphPattern: # Mul -> Div version pattern = GraphPattern() input_pattern_node = pattern.add_node(label="*INPUT_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE) - add_node = pattern.add_node(label="ADD", type="__add__") - hardtanh_node = pattern.add_node(label="HARDTANH", type="hardtanh") - truediv_node = pattern.add_node(label="DIV", type="__truediv__") - mul_node = pattern.add_node(label="MUL", type="__mul__") + add_node = pattern.add_node(label="ADD", type=om.PTAddMetatype) + hardtanh_node = pattern.add_node(label="HARDTANH", type=om.PTHardTanhMetatype) + truediv_node = pattern.add_node(label="DIV", type=om.PTDivMetatype) + mul_node = pattern.add_node(label="MUL", type=om.PTMulMetatype) pattern.add_edge(input_pattern_node, add_node) pattern.add_edge(input_pattern_node, mul_node) @@ -250,10 +251,10 @@ def create_h_swish_act() -> GraphPattern: # Div -> Mul version pattern = GraphPattern() input_pattern_node = pattern.add_node(label="*INPUT_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE) - add_node = pattern.add_node(label="ADD", type="__add__") - hardtanh_node = pattern.add_node(label="HARDTANH", type="hardtanh") - mul_node = pattern.add_node(label="MUL", type="__mul__") - truediv_node = pattern.add_node(label="DIV", type="__truediv__") + add_node = pattern.add_node(label="ADD", type=om.PTAddMetatype) + hardtanh_node = pattern.add_node(label="HARDTANH", type=om.PTHardTanhMetatype) + mul_node = pattern.add_node(label="MUL", type=om.PTMulMetatype) + truediv_node = pattern.add_node(label="DIV", type=om.PTDivMetatype) pattern.add_edge(input_pattern_node, add_node) pattern.add_edge(input_pattern_node, mul_node) @@ -265,10 +266,10 @@ def create_h_swish_act() -> GraphPattern: # ReLU6 version - Mul -> Div pattern = GraphPattern() input_pattern_node = pattern.add_node(label="*INPUT_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE) - add_node = pattern.add_node(label="ADD", type="__add__") - relu6_node = pattern.add_node(label="RELU6", type="relu6") - mul_node = pattern.add_node(label="MUL", type="__mul__") - truediv_node = pattern.add_node(label="DIV", type="__truediv__") + add_node = pattern.add_node(label="ADD", type=om.PTAddMetatype) + relu6_node = pattern.add_node(label="RELU6", type=om.PTRELU6Metatype) + mul_node = pattern.add_node(label="MUL", type=om.PTMulMetatype) + truediv_node = pattern.add_node(label="DIV", type=om.PTDivMetatype) pattern.add_edge(input_pattern_node, add_node) pattern.add_edge(input_pattern_node, mul_node) @@ -280,10 +281,10 @@ def create_h_swish_act() -> GraphPattern: # ReLU6 version - Div -> Mul pattern = GraphPattern() input_pattern_node = pattern.add_node(label="*INPUT_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE) - add_node = pattern.add_node(label="ADD", type="__add__") - relu6_node = pattern.add_node(label="RELU6", type="relu6") - truediv_node = pattern.add_node(label="DIV", type="__truediv__") - mul_node = pattern.add_node(label="MUL", type="__mul__") + add_node = pattern.add_node(label="ADD", type=om.PTAddMetatype) + relu6_node = pattern.add_node(label="RELU6", type=om.PTRELU6Metatype) + truediv_node = pattern.add_node(label="DIV", type=om.PTDivMetatype) + mul_node = pattern.add_node(label="MUL", type=om.PTMulMetatype) pattern.add_edge(input_pattern_node, add_node) pattern.add_edge(input_pattern_node, mul_node) @@ -303,9 +304,9 @@ def create_h_sigmoid_act() -> GraphPattern: pattern = GraphPattern() input_pattern_node = pattern.add_node(label="*INPUT_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE) - add_node = pattern.add_node(label="ADD", type="__add__") - hardtanh_node = pattern.add_node(label="HARDTANH", type="hardtanh") - truediv_node = pattern.add_node(label="DIV", type="__truediv__") + add_node = pattern.add_node(label="ADD", type=om.PTAddMetatype) + hardtanh_node = pattern.add_node(label="HARDTANH", type=om.PTHardTanhMetatype) + truediv_node = pattern.add_node(label="DIV", type=om.PTDivMetatype) pattern.add_edge(input_pattern_node, add_node) pattern.add_edge(add_node, hardtanh_node) @@ -317,9 +318,9 @@ def create_h_sigmoid_act() -> GraphPattern: pattern = GraphPattern() input_pattern_node = pattern.add_node(label="*INPUT_NODE*", type=GraphPattern.NON_PATTERN_NODE_TYPE) - add_node = pattern.add_node(label="ADD", type="__add__") - relu6_node = pattern.add_node(label="RELU6", type="relu6") - truediv_node = pattern.add_node(label="DIV", type="__truediv__") + add_node = pattern.add_node(label="ADD", type=om.PTAddMetatype) + relu6_node = pattern.add_node(label="RELU6", type=om.PTRELU6Metatype) + truediv_node = pattern.add_node(label="DIV", type=om.PTDivMetatype) pattern.add_edge(input_pattern_node, add_node) pattern.add_edge(add_node, relu6_node) diff --git a/nncf/torch/quantization/ignored_patterns.py b/nncf/torch/quantization/ignored_patterns.py index b1e6c522ada..895849c244e 100644 --- a/nncf/torch/quantization/ignored_patterns.py +++ b/nncf/torch/quantization/ignored_patterns.py @@ -11,6 +11,7 @@ from nncf.common.graph.patterns.patterns import GraphPattern from nncf.common.graph.patterns.patterns import IgnoredPatternNames from nncf.common.utils.registry import Registry +from nncf.torch.graph import operator_metatypes as om from nncf.torch.graph.pattern_operations import ATOMIC_ACTIVATIONS_OPERATIONS from nncf.torch.graph.pattern_operations import LINEAR_OPERATIONS @@ -19,11 +20,11 @@ def _add_softmax_matmul( pattern: GraphPattern, - matmul_aliases, - reshape_squeeze_aliases, - gather_aliases, - transpose_aliases, - concat_aliases, + matmul_metatypes, + reshape_squeeze_metatypes, + gather_metatypes, + transpose_metatypes, + concat_metatypes, ) -> None: # SOFTMAX RESHAPE||TRANSPOSE||GATHER||SQUEEZE||CONCAT # \ / @@ -32,9 +33,9 @@ def _add_softmax_matmul( # \ / # \ / # MATMUL - branch_matmul_nodes = reshape_squeeze_aliases + gather_aliases + transpose_aliases + concat_aliases - softmax = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SOFTMAX", GraphPattern.METATYPE_ATTR: "softmax"}) - matmul = pattern.add_node(**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: matmul_aliases}) + branch_matmul_nodes = reshape_squeeze_metatypes + gather_metatypes + transpose_metatypes + concat_metatypes + softmax = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SOFTMAX", GraphPattern.METATYPE_ATTR: om.PTSoftmaxMetatype}) + matmul = pattern.add_node(**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: matmul_metatypes}) matmul_branch_nodes = pattern.add_node( **{GraphPattern.LABEL_ATTR: "NON_PATTERN", GraphPattern.METATYPE_ATTR: branch_matmul_nodes} ) @@ -44,11 +45,11 @@ def _add_softmax_matmul( def _add_softmax_reshape_matmul( pattern: GraphPattern, - matmul_aliases, - reshape_squeeze_aliases, - gather_aliases, - transpose_aliases, - concat_aliases, + matmul_metatypes, + reshape_squeeze_metatypes, + gather_metatypes, + transpose_metatypes, + concat_metatypes, ) -> None: # SOFTMAX # \ @@ -62,12 +63,12 @@ def _add_softmax_reshape_matmul( # \ / # \ / # MATMUL - branch_matmul_nodes = reshape_squeeze_aliases + gather_aliases + transpose_aliases + concat_aliases - softmax = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SOFTMAX", GraphPattern.METATYPE_ATTR: "softmax"}) + branch_matmul_nodes = reshape_squeeze_metatypes + gather_metatypes + transpose_metatypes + concat_metatypes + softmax = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SOFTMAX", GraphPattern.METATYPE_ATTR: om.PTSoftmaxMetatype}) reshape = pattern.add_node( - **{GraphPattern.LABEL_ATTR: "RESHAPE", GraphPattern.METATYPE_ATTR: reshape_squeeze_aliases} + **{GraphPattern.LABEL_ATTR: "RESHAPE", GraphPattern.METATYPE_ATTR: reshape_squeeze_metatypes} ) - matmul = pattern.add_node(**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: matmul_aliases}) + matmul = pattern.add_node(**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: matmul_metatypes}) matmul_branch_nodes = pattern.add_node( **{GraphPattern.LABEL_ATTR: "RESHAPE||TRANSPOSE||GATHER", GraphPattern.METATYPE_ATTR: branch_matmul_nodes} ) @@ -79,35 +80,28 @@ def _add_softmax_reshape_matmul( @PT_IGNORED_PATTERNS.register(IgnoredPatternNames.MULTIHEAD_ATTENTION_OUTPUT) def create_multihead_attention_output() -> GraphPattern: - matmul_aliases = ["linear", "addmm", "matmul", "bmm", "mm", "baddbmm", "__matmul__"] - reshape_squeeze_aliases = [ - "reshape", - "view", - "flatten", - "unsqueeze", - "squeeze", - "unbind", - ] - gather_aliases = ["gather", "index_select", "where", "index_select", "__getitem__"] - transpose_aliases = ["transpose", "permute", "transpose_"] - concat_aliases = ["cat", "stack"] + matmul_metatypes = [om.PTLinearMetatype, om.PTAddmmMetatype, om.PTMatMulMetatype] + reshape_squeeze_metatypes = [om.PTReshapeMetatype, om.PTSqueezeMetatype, om.PTSplitMetatype] + gather_metatypes = [om.PTGatherMetatype] + transpose_metatypes = [om.PTTransposeMetatype] + concat_metatypes = [om.PTCatMetatype] pattern = GraphPattern() _add_softmax_matmul( pattern, - matmul_aliases=matmul_aliases, - reshape_squeeze_aliases=reshape_squeeze_aliases, - gather_aliases=gather_aliases, - transpose_aliases=transpose_aliases, - concat_aliases=concat_aliases, + matmul_metatypes=matmul_metatypes, + reshape_squeeze_metatypes=reshape_squeeze_metatypes, + gather_metatypes=gather_metatypes, + transpose_metatypes=transpose_metatypes, + concat_metatypes=concat_metatypes, ) _add_softmax_reshape_matmul( pattern, - matmul_aliases=matmul_aliases, - reshape_squeeze_aliases=reshape_squeeze_aliases, - gather_aliases=gather_aliases, - transpose_aliases=transpose_aliases, - concat_aliases=concat_aliases, + matmul_metatypes=matmul_metatypes, + reshape_squeeze_metatypes=reshape_squeeze_metatypes, + gather_metatypes=gather_metatypes, + transpose_metatypes=transpose_metatypes, + concat_metatypes=concat_metatypes, ) return pattern @@ -117,16 +111,16 @@ def create_multihead_attention_output() -> GraphPattern: def create_se_block() -> GraphPattern: MEAN_OPERATIONS = { GraphPattern.LABEL_ATTR: "REDUCE_MEAN", - GraphPattern.METATYPE_ATTR: ["avg_pool2d", "adaptive_avg_pool2d", "avg_pool3d", "adaptive_avg_pool3d", "mean"], + GraphPattern.METATYPE_ATTR: [om.PTAvgPool2dMetatype, om.PTAvgPool3dMetatype, om.PTMeanMetatype], GraphPattern.PATTERN_NODE_TO_EXCLUDE: True, } SYGMOID_OPERATIONS = { GraphPattern.LABEL_ATTR: "SIGMOID", - GraphPattern.METATYPE_ATTR: ["sigmoid", "hardsigmoid"], + GraphPattern.METATYPE_ATTR: [om.PTSigmoidMetatype, om.PTHardSigmoidMetatype], } MUL_OPERATION = { GraphPattern.LABEL_ATTR: "MUL", - GraphPattern.METATYPE_ATTR: "__mul__", + GraphPattern.METATYPE_ATTR: om.PTMulMetatype, GraphPattern.PATTERN_NODE_TO_EXCLUDE: True, } @@ -154,10 +148,10 @@ def get_se_block_with_bias_pattern() -> GraphPattern: any_node = pattern.add_node(label="NON_PATTERN_NODE", type=GraphPattern.NON_PATTERN_NODE_TYPE) reduce_mean_node = pattern.add_node(**MEAN_OPERATIONS) linear_node_1 = pattern.add_node(**LINEAR_OPERATIONS) - add_node_1 = pattern.add_node(label="ADD_BIAS", type=["__add__", "__sub__"]) + add_node_1 = pattern.add_node(label="ADD_BIAS", type=[om.PTAddMetatype, om.PTSubMetatype]) activation_node_1 = pattern.add_node(**ATOMIC_ACTIVATIONS_OPERATIONS) linear_node_2 = pattern.add_node(**LINEAR_OPERATIONS) - add_node_2 = pattern.add_node(label="ADD_BIAS", type=["__add__", "__sub__"]) + add_node_2 = pattern.add_node(label="ADD_BIAS", type=[om.PTAddMetatype, om.PTSubMetatype]) activation_node_2 = pattern.add_node(**SYGMOID_OPERATIONS) multiply_node = pattern.add_node(**MUL_OPERATION) @@ -174,7 +168,7 @@ def get_se_block_with_bias_pattern() -> GraphPattern: RESHAPE_NODES = { GraphPattern.LABEL_ATTR: "RESHAPE", - GraphPattern.METATYPE_ATTR: ["reshape", "view", "flatten", "unsqueeze"], + GraphPattern.METATYPE_ATTR: om.PTReshapeMetatype, } def get_se_block_with_reshape() -> GraphPattern: @@ -206,10 +200,10 @@ def get_se_block_with_bias_and_reshape() -> GraphPattern: reduce_mean_node = pattern.add_node(**MEAN_OPERATIONS) reshape_node_1 = pattern.add_node(**RESHAPE_NODES) linear_node_1 = pattern.add_node(**LINEAR_OPERATIONS) - add_node_1 = pattern.add_node(label="ADD_BIAS", type=["__add__", "__sub__"]) + add_node_1 = pattern.add_node(label="ADD_BIAS", type=[om.PTAddMetatype, om.PTSubMetatype]) activation_node_1 = pattern.add_node(**ATOMIC_ACTIVATIONS_OPERATIONS) linear_node_2 = pattern.add_node(**LINEAR_OPERATIONS) - add_node_2 = pattern.add_node(label="ADD_BIAS", type=["__add__", "__sub__"]) + add_node_2 = pattern.add_node(label="ADD_BIAS", type=[om.PTAddMetatype, om.PTSubMetatype]) activation_node_2 = pattern.add_node(**SYGMOID_OPERATIONS) reshape_node_2 = pattern.add_node(**RESHAPE_NODES) multiply_node = pattern.add_node(**MUL_OPERATION) diff --git a/tests/common/quantization/mock_graphs.py b/tests/common/quantization/mock_graphs.py index b8cb41ee142..ada50e90a21 100644 --- a/tests/common/quantization/mock_graphs.py +++ b/tests/common/quantization/mock_graphs.py @@ -21,6 +21,7 @@ from nncf.common.graph.layer_attributes import BaseLayerAttributes from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes from nncf.common.graph.layer_attributes import Dtype +from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.operator_metatypes import UnknownMetatype from nncf.common.insertion_point_graph import InsertionPointGraph from nncf.common.insertion_point_graph import PostHookInsertionPoint @@ -188,16 +189,26 @@ def get_mock_nncf_node_attrs(op_name=None, scope_str=None, metatype=None, type_= def _add_nodes_with_layer_attrs( - nx_graph: nx.DiGraph, node_keys: List[str], layer_attrs: Dict[str, BaseLayerAttributes] + nx_graph: nx.DiGraph, + node_keys: List[str], + layer_attrs: Dict[str, BaseLayerAttributes], + metatypes: Dict[str, OperatorMetatype] = None, ) -> nx.DiGraph: for node_key in node_keys: - nx_graph.add_node(node_key, **get_mock_nncf_node_attrs(op_name=node_key)) + metatype = None + if metatypes is not None and node_key in metatypes: + metatype = metatypes[node_key] + nx_graph.add_node(node_key, **get_mock_nncf_node_attrs(op_name=node_key, metatype=metatype)) + if node_key in layer_attrs: nx_graph.nodes[node_key][NNCFNode.LAYER_ATTRIBUTES] = layer_attrs[node_key] + return nx_graph -def get_mock_model_graph_with_mergeable_pattern() -> NNCFGraph: +def get_mock_model_graph_with_mergeable_pattern( + conv2d_metatype=None, batchnorm_metatype=None, relu_metatype=None +) -> NNCFGraph: mock_nx_graph = nx.DiGraph() # (A) @@ -225,7 +236,12 @@ def get_mock_model_graph_with_mergeable_pattern() -> NNCFGraph: padding_values=[0, 0, 0, 0], ) } - mock_nx_graph = _add_nodes_with_layer_attrs(mock_nx_graph, node_keys, layer_attrs) + metatypes = { + "conv2d": conv2d_metatype, + "batch_norm": batchnorm_metatype, + "relu": relu_metatype, + } + mock_nx_graph = _add_nodes_with_layer_attrs(mock_nx_graph, node_keys, layer_attrs, metatypes) mock_nx_graph.add_edges_from( [ @@ -238,7 +254,9 @@ def get_mock_model_graph_with_mergeable_pattern() -> NNCFGraph: return get_nncf_graph_from_mock_nx_graph(mock_nx_graph) -def get_mock_model_graph_with_no_mergeable_pattern() -> NNCFGraph: +def get_mock_model_graph_with_no_mergeable_pattern( + conv2d_metatype=None, batchnorm_metatype=None, relu_metatype=None +) -> NNCFGraph: mock_nx_graph = nx.DiGraph() # (A) @@ -270,7 +288,12 @@ def get_mock_model_graph_with_no_mergeable_pattern() -> NNCFGraph: padding_values=[0, 0, 0, 0], ) } - mock_nx_graph = _add_nodes_with_layer_attrs(mock_nx_graph, node_keys, layer_attrs) + metatypes = { + "conv2d": conv2d_metatype, + "batch_norm": batchnorm_metatype, + "relu": relu_metatype, + } + mock_nx_graph = _add_nodes_with_layer_attrs(mock_nx_graph, node_keys, layer_attrs, metatypes) mock_nx_graph.add_edges_from( [ @@ -285,7 +308,9 @@ def get_mock_model_graph_with_no_mergeable_pattern() -> NNCFGraph: return get_nncf_graph_from_mock_nx_graph(mock_nx_graph) -def get_mock_model_graph_with_broken_output_edge_pattern() -> NNCFGraph: +def get_mock_model_graph_with_broken_output_edge_pattern( + conv2d_metatype=None, batchnorm_metatype=None, relu_metatype=None +) -> NNCFGraph: mock_nx_graph = nx.DiGraph() # (A) @@ -314,7 +339,12 @@ def get_mock_model_graph_with_broken_output_edge_pattern() -> NNCFGraph: padding_values=[0, 0, 0, 0], ) } - mock_nx_graph = _add_nodes_with_layer_attrs(mock_nx_graph, node_keys, layer_attrs) + metatypes = { + "conv2d": conv2d_metatype, + "batch_norm": batchnorm_metatype, + "relu": relu_metatype, + } + mock_nx_graph = _add_nodes_with_layer_attrs(mock_nx_graph, node_keys, layer_attrs, metatypes) mock_nx_graph.add_edges_from( [ diff --git a/tests/torch/data/reference_graphs/quantized/synthetic_model/MHA_single_input.dot b/tests/torch/data/reference_graphs/quantized/synthetic_model/MHA_single_input.dot index 578b6e795d7..5f55464d859 100644 --- a/tests/torch/data/reference_graphs/quantized/synthetic_model/MHA_single_input.dot +++ b/tests/torch/data/reference_graphs/quantized/synthetic_model/MHA_single_input.dot @@ -2,8 +2,8 @@ strict digraph { "0 /nncf_model_input_0" [id=0, type=nncf_model_input]; "1 SymmetricQuantizer/symmetric_quantize_0" [id=1, type=symmetric_quantize]; "2 MHA_single_input/MultiheadAttention[mha]/linear_0" [id=2, type=linear]; -"3 MHA_single_input/MultiheadAttention[mha]/unflatten_0" [id=3, type=unflatten]; -"4 MHA_single_input/MultiheadAttention[mha]/SymmetricQuantizer/symmetric_quantize_0" [id=4, type=symmetric_quantize]; +"3 MHA_single_input/MultiheadAttention[mha]/SymmetricQuantizer/symmetric_quantize_0" [id=3, type=symmetric_quantize]; +"4 MHA_single_input/MultiheadAttention[mha]/unflatten_0" [id=4, type=unflatten]; "5 MHA_single_input/MultiheadAttention[mha]/unsqueeze_0" [id=5, type=unsqueeze]; "6 MHA_single_input/MultiheadAttention[mha]/transpose_0" [id=6, type=transpose]; "7 MHA_single_input/MultiheadAttention[mha]/squeeze_0" [id=7, type=squeeze]; @@ -36,9 +36,9 @@ strict digraph { "34 /nncf_model_output_1" [id=34, type=nncf_model_output]; "0 /nncf_model_input_0" -> "1 SymmetricQuantizer/symmetric_quantize_0"; "1 SymmetricQuantizer/symmetric_quantize_0" -> "2 MHA_single_input/MultiheadAttention[mha]/linear_0"; -"2 MHA_single_input/MultiheadAttention[mha]/linear_0" -> "3 MHA_single_input/MultiheadAttention[mha]/unflatten_0"; -"3 MHA_single_input/MultiheadAttention[mha]/unflatten_0" -> "4 MHA_single_input/MultiheadAttention[mha]/SymmetricQuantizer/symmetric_quantize_0"; -"4 MHA_single_input/MultiheadAttention[mha]/SymmetricQuantizer/symmetric_quantize_0" -> "5 MHA_single_input/MultiheadAttention[mha]/unsqueeze_0"; +"2 MHA_single_input/MultiheadAttention[mha]/linear_0" -> "3 MHA_single_input/MultiheadAttention[mha]/SymmetricQuantizer/symmetric_quantize_0"; +"3 MHA_single_input/MultiheadAttention[mha]/SymmetricQuantizer/symmetric_quantize_0" -> "4 MHA_single_input/MultiheadAttention[mha]/unflatten_0"; +"4 MHA_single_input/MultiheadAttention[mha]/unflatten_0" -> "5 MHA_single_input/MultiheadAttention[mha]/unsqueeze_0"; "5 MHA_single_input/MultiheadAttention[mha]/unsqueeze_0" -> "6 MHA_single_input/MultiheadAttention[mha]/transpose_0"; "6 MHA_single_input/MultiheadAttention[mha]/transpose_0" -> "7 MHA_single_input/MultiheadAttention[mha]/squeeze_0"; "7 MHA_single_input/MultiheadAttention[mha]/squeeze_0" -> "8 MHA_single_input/MultiheadAttention[mha]/contiguous_0"; diff --git a/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__multi_input_branch.dot b/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__multi_input_branch.dot index 3b5e80a9758..765baeccf24 100644 --- a/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__multi_input_branch.dot +++ b/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__multi_input_branch.dot @@ -1,25 +1,25 @@ strict digraph { "0 /nncf_model_input_0" [id=0, type=nncf_model_input]; -"1 ShiftScaleParametrized/clone_0" [id=1, type=clone]; -"2 ShiftScaleParametrized/sub__0" [id=2, type=sub_]; -"3 ShiftScaleParametrized/div__0" [id=3, type=div_]; -"4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" [id=4, type=symmetric_quantize]; -"5 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=5, type=symmetric_quantize]; -"6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" [id=6, type=conv2d]; -"7 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_1" [id=7, type=symmetric_quantize]; -"8 ShiftScaleParametrized/NNCFConv2d[conv]/SymmetricQuantizer/symmetric_quantize_0" [id=8, type=symmetric_quantize]; +"1 SymmetricQuantizer/symmetric_quantize_0" [id=1, type=symmetric_quantize]; +"2 ShiftScaleParametrized/clone_0" [id=2, type=clone]; +"3 ShiftScaleParametrized/sub__0" [id=3, type=sub_]; +"4 ShiftScaleParametrized/div__0" [id=4, type=div_]; +"5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" [id=5, type=symmetric_quantize]; +"6 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=6, type=symmetric_quantize]; +"7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" [id=7, type=conv2d]; +"8 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_1" [id=8, type=symmetric_quantize]; "9 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1" [id=9, type=conv2d]; "10 /nncf_model_output_0" [id=10, type=nncf_model_output]; "11 /nncf_model_output_1" [id=11, type=nncf_model_output]; -"0 /nncf_model_input_0" -> "1 ShiftScaleParametrized/clone_0"; -"0 /nncf_model_input_0" -> "8 ShiftScaleParametrized/NNCFConv2d[conv]/SymmetricQuantizer/symmetric_quantize_0"; -"1 ShiftScaleParametrized/clone_0" -> "2 ShiftScaleParametrized/sub__0"; -"2 ShiftScaleParametrized/sub__0" -> "3 ShiftScaleParametrized/div__0"; -"3 ShiftScaleParametrized/div__0" -> "4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0"; -"4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" -> "6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; -"5 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; -"6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" -> "10 /nncf_model_output_0"; -"7 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_1" -> "9 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1"; -"8 ShiftScaleParametrized/NNCFConv2d[conv]/SymmetricQuantizer/symmetric_quantize_0" -> "9 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1"; +"0 /nncf_model_input_0" -> "1 SymmetricQuantizer/symmetric_quantize_0"; +"1 SymmetricQuantizer/symmetric_quantize_0" -> "2 ShiftScaleParametrized/clone_0"; +"1 SymmetricQuantizer/symmetric_quantize_0" -> "9 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1"; +"2 ShiftScaleParametrized/clone_0" -> "3 ShiftScaleParametrized/sub__0"; +"3 ShiftScaleParametrized/sub__0" -> "4 ShiftScaleParametrized/div__0"; +"4 ShiftScaleParametrized/div__0" -> "5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0"; +"5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" -> "7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; +"6 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; +"7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" -> "10 /nncf_model_output_0"; +"8 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_1" -> "9 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1"; "9 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_1" -> "11 /nncf_model_output_1"; } diff --git a/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__single_input_branch.dot b/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__single_input_branch.dot index 9eab740c541..793bfa2f7da 100644 --- a/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__single_input_branch.dot +++ b/tests/torch/data/reference_graphs/quantized/synthetic_model/ShiftScale__normalize__single_input_branch.dot @@ -1,17 +1,19 @@ strict digraph { "0 /nncf_model_input_0" [id=0, type=nncf_model_input]; -"1 ShiftScaleParametrized/clone_0" [id=1, type=clone]; -"2 ShiftScaleParametrized/sub__0" [id=2, type=sub_]; -"3 ShiftScaleParametrized/div__0" [id=3, type=div_]; -"4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" [id=4, type=symmetric_quantize]; -"5 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=5, type=symmetric_quantize]; -"6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" [id=6, type=conv2d]; -"7 /nncf_model_output_0" [id=7, type=nncf_model_output]; -"0 /nncf_model_input_0" -> "1 ShiftScaleParametrized/clone_0"; -"1 ShiftScaleParametrized/clone_0" -> "2 ShiftScaleParametrized/sub__0"; -"2 ShiftScaleParametrized/sub__0" -> "3 ShiftScaleParametrized/div__0"; -"3 ShiftScaleParametrized/div__0" -> "4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0"; -"4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" -> "6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; -"5 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; -"6 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" -> "7 /nncf_model_output_0"; +"1 SymmetricQuantizer/symmetric_quantize_0" [id=1, type=symmetric_quantize]; +"2 ShiftScaleParametrized/clone_0" [id=2, type=clone]; +"3 ShiftScaleParametrized/sub__0" [id=3, type=sub_]; +"4 ShiftScaleParametrized/div__0" [id=4, type=div_]; +"5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" [id=5, type=symmetric_quantize]; +"6 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=6, type=symmetric_quantize]; +"7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" [id=7, type=conv2d]; +"8 /nncf_model_output_0" [id=8, type=nncf_model_output]; +"0 /nncf_model_input_0" -> "1 SymmetricQuantizer/symmetric_quantize_0"; +"1 SymmetricQuantizer/symmetric_quantize_0" -> "2 ShiftScaleParametrized/clone_0"; +"2 ShiftScaleParametrized/clone_0" -> "3 ShiftScaleParametrized/sub__0"; +"3 ShiftScaleParametrized/sub__0" -> "4 ShiftScaleParametrized/div__0"; +"4 ShiftScaleParametrized/div__0" -> "5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0"; +"5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" -> "7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; +"6 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0"; +"7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" -> "8 /nncf_model_output_0"; } diff --git a/tests/torch/test_model_transformer.py b/tests/torch/test_model_transformer.py index d4fe6ca2f72..3eabe1f7e99 100644 --- a/tests/torch/test_model_transformer.py +++ b/tests/torch/test_model_transformer.py @@ -12,6 +12,7 @@ from collections import Counter from copy import deepcopy from dataclasses import dataclass +from functools import partial from pathlib import Path from typing import List @@ -21,6 +22,7 @@ import torch.nn.functional as F from torch import nn +import nncf.torch.graph.operator_metatypes as om from nncf.common.graph.definitions import MODEL_INPUT_OP_NAME from nncf.common.graph.definitions import MODEL_OUTPUT_OP_NAME from nncf.common.graph.patterns.manager import PatternsManager @@ -367,9 +369,33 @@ def test_priority(self, target_type, trace_parameters, priority_type): MERGE_PATTERN_TEST_CASES = ( - [get_mock_model_graph_with_mergeable_pattern, "basic_pattern"], - [get_mock_model_graph_with_no_mergeable_pattern, "no_pattern"], - [get_mock_model_graph_with_broken_output_edge_pattern, "broken_output_edges_pattern"], + [ + partial( + get_mock_model_graph_with_mergeable_pattern, + conv2d_metatype=om.PTConv2dMetatype, + batchnorm_metatype=om.PTBatchNormMetatype, + relu_metatype=om.PTRELUMetatype, + ), + "basic_pattern", + ], + [ + partial( + get_mock_model_graph_with_no_mergeable_pattern, + conv2d_metatype=om.PTConv2dMetatype, + batchnorm_metatype=om.PTBatchNormMetatype, + relu_metatype=om.PTRELUMetatype, + ), + "no_pattern", + ], + [ + partial( + get_mock_model_graph_with_broken_output_edge_pattern, + conv2d_metatype=om.PTConv2dMetatype, + batchnorm_metatype=om.PTBatchNormMetatype, + relu_metatype=om.PTRELUMetatype, + ), + "broken_output_edges_pattern", + ], )