From da5217a88ec6d1b2c3b4b4ca099311f6682fddb5 Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Fri, 27 Sep 2024 19:54:20 +0200 Subject: [PATCH] [FX/PT/ONNX][MinMax] Weights nodes/ constant matmuls collection fix (#2944) ### Changes * Weight nodes collection is updated for the PT/FX backends * Constant matmul check is updated for the PT/FX/ONNX backends ### Reason for changes To unify self-attention quantization across all backends and align it with the OV quantization ### Related tickets #2766 ### Tests * test_model_type_transformer_quantization_config * post_training_quantization/485/ + post_training_quantization/486/ have finished successfully --- .../algorithms/min_max/algorithm.py | 2 +- .../algorithms/min_max/backend.py | 13 +- .../algorithms/min_max/onnx_backend.py | 6 +- .../algorithms/min_max/openvino_backend.py | 6 +- .../algorithms/min_max/torch_backend.py | 14 ++- .../algorithms/min_max/torch_fx_backend.py | 14 ++- tests/cross_fw/test_templates/models.py | 49 ++++++++ .../test_templates/test_quantizer_config.py | 112 ++++++++++++++---- .../quantization/test_quantizer_config.py | 23 ++++ .../quantization/test_quantizer_config.py | 21 ++++ .../data/ptq_reference_data.yaml | 2 +- tests/torch/fx/test_quantizer_config.py | 16 +++ tests/torch/ptq/test_quantizer_config.py | 18 +++ 13 files changed, 261 insertions(+), 35 deletions(-) diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 976e0e34c85..6e0d854df3c 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -1026,7 +1026,7 @@ def _apply_model_type_pass( continue if ( quantization_point.qconfig.mode != QuantizationScheme.SYMMETRIC - and node.layer_attributes is None + and not self._backend_entity.is_matmul_with_constant(node, nncf_graph) ): quantization_point.qconfig.mode = QuantizationScheme.SYMMETRIC nncf_logger.debug( diff --git a/nncf/quantization/algorithms/min_max/backend.py b/nncf/quantization/algorithms/min_max/backend.py index 0d93149c635..de562139307 100644 --- a/nncf/quantization/algorithms/min_max/backend.py +++ b/nncf/quantization/algorithms/min_max/backend.py @@ -303,12 +303,21 @@ def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]: :return: List of ignored names. """ - @staticmethod @abstractmethod - def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]: + def get_weight_nodes(self, nncf_graph: NNCFGraph) -> List[NNCFNode]: """ Returns nodes that have weights. :param nncf_graph: Instance of NNCFGraph. :return: All nodes with weights. """ + + @abstractmethod + def is_matmul_with_constant(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + """ + Returns True if the given NNCF matmul node is a matmul with a constant, False otherwise. + + :param node: Instance of NNCFNode. + :param nncf_graph: Instance of NNCFGraph. + :return: True if the given NNCF matmul node is a matmul with a constant, False otherwise.
+ """ diff --git a/nncf/quantization/algorithms/min_max/onnx_backend.py b/nncf/quantization/algorithms/min_max/onnx_backend.py index 9cfc257ad8e..a09ca2b7861 100644 --- a/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -255,8 +255,7 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]: return set() - @staticmethod - def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]: + def get_weight_nodes(self, nncf_graph: NNCFGraph) -> List[NNCFNode]: return [node for node in nncf_graph.get_all_nodes() if node.layer_attributes.has_weight()] @staticmethod @@ -268,3 +267,6 @@ def get_weight_name(nncf_graph: NNCFGraph, target_point: ONNXTargetPoint) -> str def should_quantize_weight(weight_name: str, quantized_weight_names: Set[str]) -> bool: # If the nodes share one weight tensor, we should have only one quantizer on that return weight_name not in quantized_weight_names + + def is_matmul_with_constant(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + return node.metatype in self.mat_mul_metatypes and node.layer_attributes.has_weight() diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index dfe144c1ad0..e1db64346bc 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -257,14 +257,16 @@ def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]: ignored_names.add(node.node_name) return ignored_names - @staticmethod - def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]: + def get_weight_nodes(self, nncf_graph: NNCFGraph) -> List[NNCFNode]: return [ node for node in nncf_graph.get_all_nodes() if isinstance(node.layer_attributes, OVLayerAttributes) and node.metatype in OPERATIONS_WITH_WEIGHTS ] + def is_matmul_with_constant(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + return node.metatype in self.mat_mul_metatypes and node.layer_attributes is not None + @staticmethod def get_weight_name(nncf_graph: NNCFGraph, target_point: OVTargetPoint) -> str: node = nncf_graph.get_node_by_name(target_point.target_node_name) diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index ee9329c030e..92b15c40708 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -370,10 +370,18 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]: return set() - @staticmethod - def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]: - return [ + def get_weight_nodes(self, nncf_graph: NNCFGraph) -> List[NNCFNode]: + weight_nodes_candidates = [ node for node in nncf_graph.get_all_nodes() if issubclass(node.metatype, om.PTOperatorMetatype) and node.metatype.weight_port_ids ] + weight_nodes = [] + for node in weight_nodes_candidates: + if node.metatype in self.mat_mul_metatypes and not self.is_matmul_with_constant(node, nncf_graph): + continue + weight_nodes.append(node) + return weight_nodes + + def is_matmul_with_constant(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + return node.metatype in self.mat_mul_metatypes and len(get_weight_tensor_port_ids(node, nncf_graph)) > 0 diff --git 
a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py index 7c8f008cb9c..5a170f5bef1 100644 --- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -346,10 +346,18 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]: return set() - @staticmethod - def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]: - return [ + def get_weight_nodes(self, nncf_graph: NNCFGraph) -> List[NNCFNode]: + weight_nodes_candidates = [ node for node in nncf_graph.get_all_nodes() if issubclass(node.metatype, om.PTOperatorMetatype) and node.metatype.weight_port_ids ] + weight_nodes = [] + for node in weight_nodes_candidates: + if node.metatype in self.mat_mul_metatypes and not self.is_matmul_with_constant(node, nncf_graph): + continue + weight_nodes.append(node) + return weight_nodes + + def is_matmul_with_constant(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + return node.metatype in self.mat_mul_metatypes and len(get_weight_tensor_port_ids(node, nncf_graph)) > 0 diff --git a/tests/cross_fw/test_templates/models.py b/tests/cross_fw/test_templates/models.py index 6cc2f0d5c38..c012e93443a 100644 --- a/tests/cross_fw/test_templates/models.py +++ b/tests/cross_fw/test_templates/models.py @@ -412,3 +412,52 @@ def __init__( original_mock_graph = create_mock_graph(nodes, edges) self.nncf_graph = get_nncf_graph_from_mock_nx_graph(original_mock_graph, nncf_graph_cls) + + +class NNCFGraphTransformer: + def __init__( + self, + matmul_metatype, + softmax_metatype, + transpose_metatype, + const_metatype, + mul_metatype, + matmul_layer_weighted_attrs=None, + matmul_layer_non_weighted_attrs=None, + default_layer_attrs=None, + nncf_graph_cls=NNCFGraph, + ): + # softmax((K x Q) * scale) x V.T + nodes = [ + NodeWithType("Input_1", InputNoopMetatype, layer_attributes=default_layer_attrs), + NodeWithType("W_K", const_metatype, layer_attributes=default_layer_attrs), + NodeWithType("W_Q", const_metatype, layer_attributes=default_layer_attrs), + NodeWithType("W_V", const_metatype, layer_attributes=default_layer_attrs), + NodeWithType("K", matmul_metatype, layer_attributes=matmul_layer_weighted_attrs), + NodeWithType("Q", matmul_metatype, layer_attributes=matmul_layer_weighted_attrs), + NodeWithType("V", matmul_metatype, layer_attributes=matmul_layer_weighted_attrs), + NodeWithType("K_Q", matmul_metatype, layer_attributes=matmul_layer_non_weighted_attrs), + NodeWithType("div", mul_metatype, layer_attributes=default_layer_attrs), + NodeWithType("softmax", softmax_metatype, layer_attributes=default_layer_attrs), + NodeWithType("transpose", transpose_metatype, layer_attributes=default_layer_attrs), + NodeWithType("SA_V", matmul_metatype, layer_attributes=matmul_layer_non_weighted_attrs), + NodeWithType("Output_1", OutputNoopMetatype, layer_attributes=default_layer_attrs), + ] + node_edges = [ + ("Input_1", "K"), + ("W_K", "K"), + ("Input_1", "Q"), + ("W_Q", "Q"), + ("Input_1", "V"), + ("W_V", "V"), + ("K", "K_Q"), + ("Q", "K_Q"), + ("K_Q", "div"), + ("div", "softmax"), + ("softmax", "SA_V"), + ("V", "transpose"), + ("transpose", "SA_V"), + ("SA_V", "Output_1"), + ] + original_mock_graph = create_mock_graph(nodes, node_edges) + self.nncf_graph = get_nncf_graph_from_mock_nx_graph(original_mock_graph, nncf_graph_cls) diff --git a/tests/cross_fw/test_templates/test_quantizer_config.py 
b/tests/cross_fw/test_templates/test_quantizer_config.py index a01e9ffdb8e..618007652f9 100644 --- a/tests/cross_fw/test_templates/test_quantizer_config.py +++ b/tests/cross_fw/test_templates/test_quantizer_config.py @@ -16,7 +16,10 @@ import pytest +from nncf import ModelType +from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.patterns import GraphPattern +from nncf.common.graph.patterns.manager import PatternsManager from nncf.common.graph.transformations.commands import TargetType from nncf.common.quantization.quantizer_setup import ActivationQuantizationInsertionPoint from nncf.common.quantization.quantizer_setup import SingleConfigQuantizationPoint @@ -34,6 +37,7 @@ from nncf.experimental.common.tensor_statistics.collectors import MinReducer from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase +from nncf.parameters import TargetDevice from nncf.quantization.advanced_parameters import QuantizationParameters from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization from nncf.quantization.passes import transform_to_inference_graph @@ -48,6 +52,10 @@ class TemplateTestQuantizerConfig: def get_algo_backend(self): pass + @abstractmethod + def get_backend_type(self): + pass + def check_is_min_max_statistic_collector(self, tensor_collector: TensorCollector): aggrs = [aggr.__class__ for aggr in tensor_collector.aggregators.values()] assert len(aggrs) == 2 @@ -63,11 +71,26 @@ def check_is_mean_min_max_statistic_collector(self, tensor_collector: TensorColl def get_reduction_axes(self, reducer: TensorReducerBase) -> ReductionAxes: return reducer._reduction_axes + @staticmethod + def _transform_to_inference_graph(nncf_graph: NNCFGraph, min_max_algo: MinMaxQuantization): + return transform_to_inference_graph( + deepcopy(nncf_graph), + min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph), + min_max_algo._backend_entity.shapeof_metatypes, + min_max_algo._backend_entity.dropout_metatypes, + min_max_algo._backend_entity.preserved_metatypes, + ) + @abstractmethod @pytest.fixture def single_conv_nncf_graph(self) -> NNCFGraphToTest: pass + @abstractmethod + @pytest.fixture + def transformer_nncf_graph(self) -> NNCFGraphToTest: + pass + @abstractmethod @pytest.fixture def depthwise_conv_nncf_graph(self) -> NNCFGraphToTestDepthwiseConv: @@ -129,13 +152,7 @@ def test_default_quantizer_config(self, single_conv_nncf_graph): min_max_algo = MinMaxQuantization() min_max_algo._backend_entity = self.get_algo_backend() nncf_graph = single_conv_nncf_graph.nncf_graph - inference_nncf_graph = transform_to_inference_graph( - deepcopy(nncf_graph), - min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph), - min_max_algo._backend_entity.shapeof_metatypes, - min_max_algo._backend_entity.dropout_metatypes, - min_max_algo._backend_entity.preserved_metatypes, - ) + inference_nncf_graph = self._transform_to_inference_graph(nncf_graph, min_max_algo) q_setup = min_max_algo._get_quantizer_setup( nncf_graph, inference_nncf_graph, hw_patterns=GraphPattern(), ignored_patterns=GraphPattern() ) @@ -184,13 +201,7 @@ def test_quantizer_config_from_ptq_params_for_CPU( ) min_max_algo._backend_entity = self.get_algo_backend() nncf_graph = single_conv_nncf_graph.nncf_graph - inference_nncf_graph = transform_to_inference_graph( - deepcopy(nncf_graph), - 
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph), - min_max_algo._backend_entity.shapeof_metatypes, - min_max_algo._backend_entity.dropout_metatypes, - min_max_algo._backend_entity.preserved_metatypes, - ) + inference_nncf_graph = self._transform_to_inference_graph(nncf_graph, min_max_algo) if signed_weights is False or signed_activations in [True, False]: # Incompatible with HW CPU config with pytest.raises( ValueError, @@ -227,13 +238,7 @@ def test_depthwise_conv_default_quantizer_config(self, depthwise_conv_nncf_graph min_max_algo = MinMaxQuantization() min_max_algo._backend_entity = self.get_algo_backend() nncf_graph = depthwise_conv_nncf_graph.nncf_graph - inference_nncf_graph = transform_to_inference_graph( - deepcopy(nncf_graph), - min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph), - min_max_algo._backend_entity.shapeof_metatypes, - min_max_algo._backend_entity.dropout_metatypes, - min_max_algo._backend_entity.preserved_metatypes, - ) + inference_nncf_graph = self._transform_to_inference_graph(nncf_graph, min_max_algo) q_setup = min_max_algo._get_quantizer_setup( nncf_graph, inference_nncf_graph, hw_patterns=GraphPattern(), ignored_patterns=GraphPattern() ) @@ -253,6 +258,71 @@ def test_depthwise_conv_default_quantizer_config(self, depthwise_conv_nncf_graph if quantization_point.is_activation_quantization_point(): assert quantization_point.qconfig == activation_default_config + REF_TRANSFORMER_SETUP_STATE = { + "quantization_points": { + 4: { + "qip": {"target_node_name": "/K_0", "input_port_id": None}, + "qip_class": "ActivationQuantizationInsertionPoint", + "qconfig": {"num_bits": 8, "mode": "symmetric", "signedness_to_force": None, "per_channel": False}, + "directly_quantized_operator_node_names": ["/K_Q_0"], + }, + 5: { + "qip": {"target_node_name": "/Q_0", "input_port_id": None}, + "qip_class": "ActivationQuantizationInsertionPoint", + "qconfig": {"num_bits": 8, "mode": "symmetric", "signedness_to_force": None, "per_channel": False}, + "directly_quantized_operator_node_names": ["/K_Q_0"], + }, + 6: { + "qip": {"target_node_name": "/Input_1_0", "input_port_id": None}, + "qip_class": "ActivationQuantizationInsertionPoint", + "qconfig": {"num_bits": 8, "mode": "asymmetric", "signedness_to_force": None, "per_channel": False}, + "directly_quantized_operator_node_names": ["/K_0", "/Q_0", "/V_0"], + }, + 8: { + "qip": {"target_node_name": "/K_0"}, + "qip_class": "WeightQuantizationInsertionPoint", + "qconfig": {"num_bits": 8, "mode": "symmetric", "signedness_to_force": True, "per_channel": True}, + "directly_quantized_operator_node_names": ["/K_0"], + }, + 9: { + "qip": {"target_node_name": "/Q_0"}, + "qip_class": "WeightQuantizationInsertionPoint", + "qconfig": {"num_bits": 8, "mode": "symmetric", "signedness_to_force": True, "per_channel": True}, + "directly_quantized_operator_node_names": ["/Q_0"], + }, + 10: { + "qip": {"target_node_name": "/V_0"}, + "qip_class": "WeightQuantizationInsertionPoint", + "qconfig": {"num_bits": 8, "mode": "symmetric", "signedness_to_force": True, "per_channel": True}, + "directly_quantized_operator_node_names": ["/V_0"], + }, + }, + "unified_scale_groups": {}, + "shared_input_operation_set_groups": {0: [4, 5], 1: [8, 9, 10, 6]}, + } + + def test_model_type_transformer_quantization_config(self, transformer_nncf_graph): + min_max_algo = MinMaxQuantization(model_type=ModelType.TRANSFORMER) + min_max_algo._backend_entity = self.get_algo_backend() + nncf_graph = 
transformer_nncf_graph.nncf_graph + inference_nncf_graph = self._transform_to_inference_graph(nncf_graph, min_max_algo) + hw_patterns = PatternsManager.get_full_hw_pattern_graph( + backend=self.get_backend_type(), device=TargetDevice.ANY, model_type=ModelType.TRANSFORMER + ) + ignored_patterns = PatternsManager.get_full_ignored_pattern_graph( + backend=self.get_backend_type(), device=TargetDevice.ANY, model_type=ModelType.TRANSFORMER + ) + q_setup = min_max_algo._get_quantizer_setup( + nncf_graph, inference_nncf_graph, hw_patterns=hw_patterns, ignored_patterns=ignored_patterns + ) + min_max_algo._apply_model_type_pass(ModelType.TRANSFORMER, q_setup, nncf_graph) + + state = q_setup.get_state() + state["quantization_points"][6]["directly_quantized_operator_node_names"] = sorted( + state["quantization_points"][6]["directly_quantized_operator_node_names"] + ) + assert state == self.REF_TRANSFORMER_SETUP_STATE + @pytest.mark.parametrize( "range_estimator_params", [RangeEstimatorParametersSet.MINMAX, RangeEstimatorParametersSet.MEAN_MINMAX] ) diff --git a/tests/onnx/quantization/test_quantizer_config.py b/tests/onnx/quantization/test_quantizer_config.py index b82318b34db..1ff5563f30f 100644 --- a/tests/onnx/quantization/test_quantizer_config.py +++ b/tests/onnx/quantization/test_quantizer_config.py @@ -11,14 +11,21 @@ import pytest +from nncf.common.utils.backend import BackendType from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXAddLayerMetatype +from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXConstantMetatype from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXConvolutionMetatype from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXDepthwiseConvolutionMetatype +from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXMatMulMetatype +from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXMulLayerMetatype +from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXSoftmaxMetatype +from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXTransposeMetatype from nncf.onnx.graph.nncf_graph_builder import ONNXLayerAttributes from nncf.quantization.algorithms.min_max.onnx_backend import ONNXMinMaxAlgoBackend from tests.cross_fw.test_templates.models import NNCFGraphToTest from tests.cross_fw.test_templates.models import NNCFGraphToTestDepthwiseConv from tests.cross_fw.test_templates.models import NNCFGraphToTestSumAggregation +from tests.cross_fw.test_templates.models import NNCFGraphTransformer from tests.cross_fw.test_templates.test_quantizer_config import TemplateTestQuantizerConfig @@ -26,6 +33,9 @@ class TestQuantizerConfig(TemplateTestQuantizerConfig): def get_algo_backend(self): return ONNXMinMaxAlgoBackend() + def get_backend_type(self): + return BackendType.ONNX + @pytest.fixture def single_conv_nncf_graph(self) -> NNCFGraphToTest: conv_layer_attrs = ONNXLayerAttributes(weight_attrs={1: {"shape": [4, 4, 4, 4]}}, bias_attrs={}) @@ -59,3 +69,16 @@ def conv_sum_aggregation_nncf_graph(self) -> NNCFGraphToTestSumAggregation: output_layer_attrs=ONNXLayerAttributes(), const_layer_attrs=ONNXLayerAttributes(), ) + + @pytest.fixture + def transformer_nncf_graph(self) -> NNCFGraphToTest: + return NNCFGraphTransformer( + matmul_metatype=ONNXMatMulMetatype, + softmax_metatype=ONNXSoftmaxMetatype, + mul_metatype=ONNXMulLayerMetatype, + const_metatype=ONNXConstantMetatype, + transpose_metatype=ONNXTransposeMetatype, + matmul_layer_weighted_attrs=ONNXLayerAttributes({"name": "edge_name", "shape": (1, 1, 1, 1)}), + matmul_layer_non_weighted_attrs=ONNXLayerAttributes(), + 
default_layer_attrs=ONNXLayerAttributes(), + ) diff --git a/tests/openvino/native/quantization/test_quantizer_config.py b/tests/openvino/native/quantization/test_quantizer_config.py index ddb3d99ead4..1be390fbce9 100644 --- a/tests/openvino/native/quantization/test_quantizer_config.py +++ b/tests/openvino/native/quantization/test_quantizer_config.py @@ -11,14 +11,21 @@ import pytest +from nncf.common.utils.backend import BackendType from nncf.openvino.graph.layer_attributes import OVLayerAttributes +from nncf.openvino.graph.metatypes.openvino_metatypes import OVConstantMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVDepthwiseConvolutionMetatype +from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype +from nncf.openvino.graph.metatypes.openvino_metatypes import OVMultiplyMetatype +from nncf.openvino.graph.metatypes.openvino_metatypes import OVSoftmaxMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVSumMetatype +from nncf.openvino.graph.metatypes.openvino_metatypes import OVTransposeMetatype from nncf.quantization.algorithms.min_max.openvino_backend import OVMinMaxAlgoBackend from tests.cross_fw.test_templates.models import NNCFGraphToTest from tests.cross_fw.test_templates.models import NNCFGraphToTestDepthwiseConv from tests.cross_fw.test_templates.models import NNCFGraphToTestSumAggregation +from tests.cross_fw.test_templates.models import NNCFGraphTransformer from tests.cross_fw.test_templates.test_quantizer_config import TemplateTestQuantizerConfig @@ -26,6 +33,9 @@ class TestQuantizerConfig(TemplateTestQuantizerConfig): def get_algo_backend(self): return OVMinMaxAlgoBackend() + def get_backend_type(self): + return BackendType.OPENVINO + @pytest.fixture def single_conv_nncf_graph(self) -> NNCFGraphToTest: conv_layer_attrs = OVLayerAttributes({0: {"name": "dummy", "shape": (4, 4, 4, 4), "dtype": "f32"}}) @@ -39,3 +49,14 @@ def depthwise_conv_nncf_graph(self): def conv_sum_aggregation_nncf_graph(self) -> NNCFGraphToTestSumAggregation: conv_layer_attrs = OVLayerAttributes({0: {"name": "dummy", "shape": (4, 4, 4, 4), "dtype": "f32"}}) return NNCFGraphToTestSumAggregation(OVConvolutionMetatype, OVSumMetatype, conv_layer_attrs) + + @pytest.fixture + def transformer_nncf_graph(self) -> NNCFGraphToTest: + return NNCFGraphTransformer( + matmul_metatype=OVMatMulMetatype, + softmax_metatype=OVSoftmaxMetatype, + mul_metatype=OVMultiplyMetatype, + const_metatype=OVConstantMetatype, + transpose_metatype=OVTransposeMetatype, + matmul_layer_weighted_attrs=OVLayerAttributes({}), + ) diff --git a/tests/post_training/data/ptq_reference_data.yaml b/tests/post_training/data/ptq_reference_data.yaml index 8a922a89598..b23f446c7ac 100644 --- a/tests/post_training/data/ptq_reference_data.yaml +++ b/tests/post_training/data/ptq_reference_data.yaml @@ -51,7 +51,7 @@ torchvision/vit_b_16_backend_FP32: torchvision/vit_b_16_backend_OV: metric_value: 0.80948 torchvision/vit_b_16_backend_FX_TORCH: - metric_value: 0.80702 + metric_value: 0.80922 torchvision/swin_v2_s_backend_FP32: metric_value: 0.83712 torchvision/swin_v2_s_backend_OV: diff --git a/tests/torch/fx/test_quantizer_config.py b/tests/torch/fx/test_quantizer_config.py index 5bc51e51332..b5927bc2caa 100644 --- a/tests/torch/fx/test_quantizer_config.py +++ b/tests/torch/fx/test_quantizer_config.py @@ -11,10 +11,13 @@ import pytest +import nncf.torch.graph.operator_metatypes as om +from nncf.common.utils.backend 
import BackendType from nncf.quantization.algorithms.min_max.torch_fx_backend import FXMinMaxAlgoBackend from tests.cross_fw.test_templates.models import NNCFGraphToTest from tests.cross_fw.test_templates.models import NNCFGraphToTestDepthwiseConv from tests.cross_fw.test_templates.models import NNCFGraphToTestSumAggregation +from tests.cross_fw.test_templates.models import NNCFGraphTransformer from tests.cross_fw.test_templates.test_quantizer_config import TemplateTestQuantizerConfig from tests.torch.fx.helpers import get_depthwise_conv_nncf_graph from tests.torch.fx.helpers import get_single_conv_nncf_graph @@ -25,6 +28,9 @@ class TestQuantizerConfig(TemplateTestQuantizerConfig): def get_algo_backend(self): return FXMinMaxAlgoBackend() + def get_backend_type(self): + return BackendType.TORCH_FX + @pytest.fixture def single_conv_nncf_graph(self) -> NNCFGraphToTest: return get_single_conv_nncf_graph() @@ -36,3 +42,13 @@ def depthwise_conv_nncf_graph(self) -> NNCFGraphToTestDepthwiseConv: @pytest.fixture def conv_sum_aggregation_nncf_graph(self) -> NNCFGraphToTestSumAggregation: return get_sum_aggregation_nncf_graph() + + @pytest.fixture + def transformer_nncf_graph(self) -> NNCFGraphToTest: + return NNCFGraphTransformer( + matmul_metatype=om.PTMatMulMetatype, + softmax_metatype=om.PTSoftmaxMetatype, + mul_metatype=om.PTMulMetatype, + const_metatype=om.PTConstNoopMetatype, + transpose_metatype=om.PTTransposeMetatype, + ) diff --git a/tests/torch/ptq/test_quantizer_config.py b/tests/torch/ptq/test_quantizer_config.py index acc8cb4002d..1b64b6cbd87 100644 --- a/tests/torch/ptq/test_quantizer_config.py +++ b/tests/torch/ptq/test_quantizer_config.py @@ -11,10 +11,14 @@ import pytest +import nncf.torch.graph.operator_metatypes as om +from nncf.common.utils.backend import BackendType from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend +from nncf.torch.graph.graph import PTNNCFGraph from tests.cross_fw.test_templates.models import NNCFGraphToTest from tests.cross_fw.test_templates.models import NNCFGraphToTestDepthwiseConv from tests.cross_fw.test_templates.models import NNCFGraphToTestSumAggregation +from tests.cross_fw.test_templates.models import NNCFGraphTransformer from tests.cross_fw.test_templates.test_quantizer_config import TemplateTestQuantizerConfig from tests.torch.ptq.helpers import get_depthwise_conv_nncf_graph from tests.torch.ptq.helpers import get_single_conv_nncf_graph @@ -25,6 +29,9 @@ class TestQuantizerConfig(TemplateTestQuantizerConfig): def get_algo_backend(self): return PTMinMaxAlgoBackend() + def get_backend_type(self): + return BackendType.TORCH + @pytest.fixture def single_conv_nncf_graph(self) -> NNCFGraphToTest: return get_single_conv_nncf_graph() @@ -36,3 +43,14 @@ def depthwise_conv_nncf_graph(self) -> NNCFGraphToTestDepthwiseConv: @pytest.fixture def conv_sum_aggregation_nncf_graph(self) -> NNCFGraphToTestSumAggregation: return get_sum_aggregation_nncf_graph() + + @pytest.fixture + def transformer_nncf_graph(self) -> NNCFGraphToTest: + return NNCFGraphTransformer( + matmul_metatype=om.PTMatMulMetatype, + softmax_metatype=om.PTSoftmaxMetatype, + mul_metatype=om.PTMulMetatype, + const_metatype=om.PTConstNoopMetatype, + transpose_metatype=om.PTTransposeMetatype, + nncf_graph_cls=PTNNCFGraph, + )
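
Below is a minimal standalone sketch (not part of the patch; all names are hypothetical stand-ins, not NNCF APIs) of the rule that `_apply_model_type_pass` now applies for `ModelType.TRANSFORMER`: the backend is asked whether a matmul consumes a constant instead of checking `node.layer_attributes is None`, so activation quantizers feeding activation-activation matmuls (K x Q and attention x V) are forced to symmetric mode, while matmuls with constant weights keep their asymmetric activation config.

```python
# Standalone sketch of the new model-type pass rule; names here are hypothetical
# stand-ins and do not mirror NNCF's actual classes.
from dataclasses import dataclass


@dataclass
class QuantizationPointSketch:
    target_node: str  # name of the matmul this activation quantizer feeds
    mode: str         # "symmetric" or "asymmetric"


def is_matmul_with_constant(node_name: str, constant_matmuls: set) -> bool:
    # In the patch this question is answered per backend: weight tensor port ids
    # for PT/FX, layer attributes for ONNX/OV. Here it is a plain lookup.
    return node_name in constant_matmuls


def apply_model_type_pass(points, constant_matmuls):
    for point in points:
        # Old check: `node.layer_attributes is None`. New check: ask the backend
        # whether the matmul consumes a constant; only non-constant
        # (activation x activation) matmuls are forced to symmetric quantization.
        if point.mode != "symmetric" and not is_matmul_with_constant(point.target_node, constant_matmuls):
            point.mode = "symmetric"
    return points


points = [
    QuantizationPointSketch("K", "asymmetric"),     # feeds the weighted matmul K = Input x W_K
    QuantizationPointSketch("K_Q", "asymmetric"),   # feeds the activation x activation matmul K_Q
    QuantizationPointSketch("SA_V", "asymmetric"),  # feeds the activation x activation matmul SA_V
]
print(apply_model_type_pass(points, constant_matmuls={"K"}))
# K keeps its asymmetric config; K_Q and SA_V are forced to symmetric,
# matching REF_TRANSFORMER_SETUP_STATE in the template test above.
```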