Skip to content

Commit

Permalink
[FX/PT/ONNX][MinMax] Weights nodes/ constant matmuls collection fix (#…
Browse files Browse the repository at this point in the history
…2944)

### Changes

* Weights nodes collection is updated for PT/FX backends
* Constant matmul check is updated for PT/FX/ONNX backends

### Reason for changes

To unify self attention quantization across all backends and align with
the OV quantization

### Related tickets

#2766 

### Tests

* test_model_type_transformer_quantization_config
* post_training_quantization/485/ + post_training_quantization/486/ is
finished successfully
  • Loading branch information
daniil-lyakhov authored Sep 27, 2024
1 parent c93676d commit da5217a
Show file tree
Hide file tree
Showing 13 changed files with 261 additions and 35 deletions.
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,7 @@ def _apply_model_type_pass(
continue
if (
quantization_point.qconfig.mode != QuantizationScheme.SYMMETRIC
and node.layer_attributes is None
and not self._backend_entity.is_matmul_with_constant(node, nncf_graph)
):
quantization_point.qconfig.mode = QuantizationScheme.SYMMETRIC
nncf_logger.debug(
Expand Down
13 changes: 11 additions & 2 deletions nncf/quantization/algorithms/min_max/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,21 @@ def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]:
:return: List of ignored names.
"""

@staticmethod
@abstractmethod
def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]:
def get_weight_nodes(self, nncf_graph: NNCFGraph) -> List[NNCFNode]:
"""
Returns nodes that have weights.
:param nncf_graph: Instance of NNCFGraph.
:return: All nodes with weights.
"""

@abstractmethod
def is_matmul_with_constant(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
"""
Returns true if given nncf matmul node is a matmul with a constant, False otherwise.
:param Node: Instance of NNCFNode.
:param nncf_graph: Instance of NNCFGraph.
:return: True if given nncf matmul node is a matmul with a constant, False otherwise.
"""
6 changes: 4 additions & 2 deletions nncf/quantization/algorithms/min_max/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,7 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O
def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]:
return set()

@staticmethod
def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]:
def get_weight_nodes(self, nncf_graph: NNCFGraph) -> List[NNCFNode]:
return [node for node in nncf_graph.get_all_nodes() if node.layer_attributes.has_weight()]

@staticmethod
Expand All @@ -268,3 +267,6 @@ def get_weight_name(nncf_graph: NNCFGraph, target_point: ONNXTargetPoint) -> str
def should_quantize_weight(weight_name: str, quantized_weight_names: Set[str]) -> bool:
# If the nodes share one weight tensor, we should have only one quantizer on that
return weight_name not in quantized_weight_names

def is_matmul_with_constant(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
return node.metatype in self.mat_mul_metatypes and node.layer_attributes.has_weight()
6 changes: 4 additions & 2 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,16 @@ def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]:
ignored_names.add(node.node_name)
return ignored_names

@staticmethod
def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]:
def get_weight_nodes(self, nncf_graph: NNCFGraph) -> List[NNCFNode]:
return [
node
for node in nncf_graph.get_all_nodes()
if isinstance(node.layer_attributes, OVLayerAttributes) and node.metatype in OPERATIONS_WITH_WEIGHTS
]

def is_matmul_with_constant(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
return node.metatype in self.mat_mul_metatypes and node.layer_attributes is not None

@staticmethod
def get_weight_name(nncf_graph: NNCFGraph, target_point: OVTargetPoint) -> str:
node = nncf_graph.get_node_by_name(target_point.target_node_name)
Expand Down
14 changes: 11 additions & 3 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,10 +370,18 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O
def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]:
return set()

@staticmethod
def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]:
return [
def get_weight_nodes(self, nncf_graph: NNCFGraph) -> List[NNCFNode]:
weight_nodes_candidates = [
node
for node in nncf_graph.get_all_nodes()
if issubclass(node.metatype, om.PTOperatorMetatype) and node.metatype.weight_port_ids
]
weight_nodes = []
for node in weight_nodes_candidates:
if node.metatype in self.mat_mul_metatypes and not self.is_matmul_with_constant(node, nncf_graph):
continue
weight_nodes.append(node)
return weight_nodes

def is_matmul_with_constant(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
return node.metatype in self.mat_mul_metatypes and len(get_weight_tensor_port_ids(node, nncf_graph)) > 0
14 changes: 11 additions & 3 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,18 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O
def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]:
return set()

@staticmethod
def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]:
return [
def get_weight_nodes(self, nncf_graph: NNCFGraph) -> List[NNCFNode]:
weight_nodes_candidates = [
node
for node in nncf_graph.get_all_nodes()
if issubclass(node.metatype, om.PTOperatorMetatype) and node.metatype.weight_port_ids
]
weight_nodes = []
for node in weight_nodes_candidates:
if node.metatype in self.mat_mul_metatypes and not self.is_matmul_with_constant(node, nncf_graph):
continue
weight_nodes.append(node)
return weight_nodes

def is_matmul_with_constant(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
return node.metatype in self.mat_mul_metatypes and len(get_weight_tensor_port_ids(node, nncf_graph)) > 0
49 changes: 49 additions & 0 deletions tests/cross_fw/test_templates/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,52 @@ def __init__(

original_mock_graph = create_mock_graph(nodes, edges)
self.nncf_graph = get_nncf_graph_from_mock_nx_graph(original_mock_graph, nncf_graph_cls)


class NNCFGraphTransformer:
def __init__(
self,
matmul_metatype,
softmax_metatype,
transpose_metatype,
const_metatype,
mul_metatype,
matmul_layer_weighted_attrs=None,
matmul_layer_non_weighted_attrs=None,
default_layer_attrs=None,
nncf_graph_cls=NNCFGraph,
):
# softmax((K x Q) * scale) x V.T
nodes = [
NodeWithType("Input_1", InputNoopMetatype, layer_attributes=default_layer_attrs),
NodeWithType("W_K", const_metatype, layer_attributes=default_layer_attrs),
NodeWithType("W_Q", const_metatype, layer_attributes=default_layer_attrs),
NodeWithType("W_V", const_metatype, layer_attributes=default_layer_attrs),
NodeWithType("K", matmul_metatype, layer_attributes=matmul_layer_weighted_attrs),
NodeWithType("Q", matmul_metatype, layer_attributes=matmul_layer_weighted_attrs),
NodeWithType("V", matmul_metatype, layer_attributes=matmul_layer_weighted_attrs),
NodeWithType("K_Q", matmul_metatype, layer_attributes=matmul_layer_non_weighted_attrs),
NodeWithType("div", mul_metatype, layer_attributes=default_layer_attrs),
NodeWithType("softmax", softmax_metatype, layer_attributes=default_layer_attrs),
NodeWithType("transpose", transpose_metatype, layer_attributes=default_layer_attrs),
NodeWithType("SA_V", matmul_metatype, layer_attributes=matmul_layer_non_weighted_attrs),
NodeWithType("Output_1", OutputNoopMetatype, layer_attributes=default_layer_attrs),
]
node_edges = [
("Input_1", "K"),
("W_K", "K"),
("Input_1", "Q"),
("W_Q", "Q"),
("Input_1", "V"),
("W_V", "V"),
("K", "K_Q"),
("Q", "K_Q"),
("K_Q", "div"),
("div", "softmax"),
("softmax", "SA_V"),
("V", "transpose"),
("transpose", "SA_V"),
("SA_V", "Output_1"),
]
original_mock_graph = create_mock_graph(nodes, node_edges)
self.nncf_graph = get_nncf_graph_from_mock_nx_graph(original_mock_graph, nncf_graph_cls)
112 changes: 91 additions & 21 deletions tests/cross_fw/test_templates/test_quantizer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

import pytest

from nncf import ModelType
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.patterns import GraphPattern
from nncf.common.graph.patterns.manager import PatternsManager
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.quantization.quantizer_setup import ActivationQuantizationInsertionPoint
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizationPoint
Expand All @@ -34,6 +37,7 @@
from nncf.experimental.common.tensor_statistics.collectors import MinReducer
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import QuantizationParameters
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
from nncf.quantization.passes import transform_to_inference_graph
Expand All @@ -48,6 +52,10 @@ class TemplateTestQuantizerConfig:
def get_algo_backend(self):
pass

@abstractmethod
def get_backend_type(self):
pass

def check_is_min_max_statistic_collector(self, tensor_collector: TensorCollector):
aggrs = [aggr.__class__ for aggr in tensor_collector.aggregators.values()]
assert len(aggrs) == 2
Expand All @@ -63,11 +71,26 @@ def check_is_mean_min_max_statistic_collector(self, tensor_collector: TensorColl
def get_reduction_axes(self, reducer: TensorReducerBase) -> ReductionAxes:
return reducer._reduction_axes

@staticmethod
def _transform_to_inference_graph(nncf_graph: NNCFGraph, min_max_algo: MinMaxQuantization):
return transform_to_inference_graph(
deepcopy(nncf_graph),
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
min_max_algo._backend_entity.preserved_metatypes,
)

@abstractmethod
@pytest.fixture
def single_conv_nncf_graph(self) -> NNCFGraphToTest:
pass

@abstractmethod
@pytest.fixture
def transformer_nncf_graph(self) -> NNCFGraphToTest:
pass

@abstractmethod
@pytest.fixture
def depthwise_conv_nncf_graph(self) -> NNCFGraphToTestDepthwiseConv:
Expand Down Expand Up @@ -129,13 +152,7 @@ def test_default_quantizer_config(self, single_conv_nncf_graph):
min_max_algo = MinMaxQuantization()
min_max_algo._backend_entity = self.get_algo_backend()
nncf_graph = single_conv_nncf_graph.nncf_graph
inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph),
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
min_max_algo._backend_entity.preserved_metatypes,
)
inference_nncf_graph = self._transform_to_inference_graph(nncf_graph, min_max_algo)
q_setup = min_max_algo._get_quantizer_setup(
nncf_graph, inference_nncf_graph, hw_patterns=GraphPattern(), ignored_patterns=GraphPattern()
)
Expand Down Expand Up @@ -184,13 +201,7 @@ def test_quantizer_config_from_ptq_params_for_CPU(
)
min_max_algo._backend_entity = self.get_algo_backend()
nncf_graph = single_conv_nncf_graph.nncf_graph
inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph),
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
min_max_algo._backend_entity.preserved_metatypes,
)
inference_nncf_graph = self._transform_to_inference_graph(nncf_graph, min_max_algo)
if signed_weights is False or signed_activations in [True, False]: # Incompatible with HW CPU config
with pytest.raises(
ValueError,
Expand Down Expand Up @@ -227,13 +238,7 @@ def test_depthwise_conv_default_quantizer_config(self, depthwise_conv_nncf_graph
min_max_algo = MinMaxQuantization()
min_max_algo._backend_entity = self.get_algo_backend()
nncf_graph = depthwise_conv_nncf_graph.nncf_graph
inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph),
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
min_max_algo._backend_entity.preserved_metatypes,
)
inference_nncf_graph = self._transform_to_inference_graph(nncf_graph, min_max_algo)
q_setup = min_max_algo._get_quantizer_setup(
nncf_graph, inference_nncf_graph, hw_patterns=GraphPattern(), ignored_patterns=GraphPattern()
)
Expand All @@ -253,6 +258,71 @@ def test_depthwise_conv_default_quantizer_config(self, depthwise_conv_nncf_graph
if quantization_point.is_activation_quantization_point():
assert quantization_point.qconfig == activation_default_config

REF_TRANSFORMER_SETUP_STATE = {
"quantization_points": {
4: {
"qip": {"target_node_name": "/K_0", "input_port_id": None},
"qip_class": "ActivationQuantizationInsertionPoint",
"qconfig": {"num_bits": 8, "mode": "symmetric", "signedness_to_force": None, "per_channel": False},
"directly_quantized_operator_node_names": ["/K_Q_0"],
},
5: {
"qip": {"target_node_name": "/Q_0", "input_port_id": None},
"qip_class": "ActivationQuantizationInsertionPoint",
"qconfig": {"num_bits": 8, "mode": "symmetric", "signedness_to_force": None, "per_channel": False},
"directly_quantized_operator_node_names": ["/K_Q_0"],
},
6: {
"qip": {"target_node_name": "/Input_1_0", "input_port_id": None},
"qip_class": "ActivationQuantizationInsertionPoint",
"qconfig": {"num_bits": 8, "mode": "asymmetric", "signedness_to_force": None, "per_channel": False},
"directly_quantized_operator_node_names": ["/K_0", "/Q_0", "/V_0"],
},
8: {
"qip": {"target_node_name": "/K_0"},
"qip_class": "WeightQuantizationInsertionPoint",
"qconfig": {"num_bits": 8, "mode": "symmetric", "signedness_to_force": True, "per_channel": True},
"directly_quantized_operator_node_names": ["/K_0"],
},
9: {
"qip": {"target_node_name": "/Q_0"},
"qip_class": "WeightQuantizationInsertionPoint",
"qconfig": {"num_bits": 8, "mode": "symmetric", "signedness_to_force": True, "per_channel": True},
"directly_quantized_operator_node_names": ["/Q_0"],
},
10: {
"qip": {"target_node_name": "/V_0"},
"qip_class": "WeightQuantizationInsertionPoint",
"qconfig": {"num_bits": 8, "mode": "symmetric", "signedness_to_force": True, "per_channel": True},
"directly_quantized_operator_node_names": ["/V_0"],
},
},
"unified_scale_groups": {},
"shared_input_operation_set_groups": {0: [4, 5], 1: [8, 9, 10, 6]},
}

def test_model_type_transformer_quantization_config(self, transformer_nncf_graph):
min_max_algo = MinMaxQuantization(model_type=ModelType.TRANSFORMER)
min_max_algo._backend_entity = self.get_algo_backend()
nncf_graph = transformer_nncf_graph.nncf_graph
inference_nncf_graph = self._transform_to_inference_graph(nncf_graph, min_max_algo)
hw_patterns = PatternsManager.get_full_hw_pattern_graph(
backend=self.get_backend_type(), device=TargetDevice.ANY, model_type=ModelType.TRANSFORMER
)
ignored_patterns = PatternsManager.get_full_ignored_pattern_graph(
backend=self.get_backend_type(), device=TargetDevice.ANY, model_type=ModelType.TRANSFORMER
)
q_setup = min_max_algo._get_quantizer_setup(
nncf_graph, inference_nncf_graph, hw_patterns=hw_patterns, ignored_patterns=ignored_patterns
)
min_max_algo._apply_model_type_pass(ModelType.TRANSFORMER, q_setup, nncf_graph)

state = q_setup.get_state()
state["quantization_points"][6]["directly_quantized_operator_node_names"] = sorted(
state["quantization_points"][6]["directly_quantized_operator_node_names"]
)
assert state == self.REF_TRANSFORMER_SETUP_STATE

@pytest.mark.parametrize(
"range_estimator_params", [RangeEstimatorParametersSet.MINMAX, RangeEstimatorParametersSet.MEAN_MINMAX]
)
Expand Down
23 changes: 23 additions & 0 deletions tests/onnx/quantization/test_quantizer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,31 @@

import pytest

from nncf.common.utils.backend import BackendType
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXAddLayerMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXConstantMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXConvolutionMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXDepthwiseConvolutionMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXMatMulMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXMulLayerMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXSoftmaxMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXTransposeMetatype
from nncf.onnx.graph.nncf_graph_builder import ONNXLayerAttributes
from nncf.quantization.algorithms.min_max.onnx_backend import ONNXMinMaxAlgoBackend
from tests.cross_fw.test_templates.models import NNCFGraphToTest
from tests.cross_fw.test_templates.models import NNCFGraphToTestDepthwiseConv
from tests.cross_fw.test_templates.models import NNCFGraphToTestSumAggregation
from tests.cross_fw.test_templates.models import NNCFGraphTransformer
from tests.cross_fw.test_templates.test_quantizer_config import TemplateTestQuantizerConfig


class TestQuantizerConfig(TemplateTestQuantizerConfig):
def get_algo_backend(self):
return ONNXMinMaxAlgoBackend()

def get_backend_type(self):
return BackendType.ONNX

@pytest.fixture
def single_conv_nncf_graph(self) -> NNCFGraphToTest:
conv_layer_attrs = ONNXLayerAttributes(weight_attrs={1: {"shape": [4, 4, 4, 4]}}, bias_attrs={})
Expand Down Expand Up @@ -59,3 +69,16 @@ def conv_sum_aggregation_nncf_graph(self) -> NNCFGraphToTestSumAggregation:
output_layer_attrs=ONNXLayerAttributes(),
const_layer_attrs=ONNXLayerAttributes(),
)

@pytest.fixture
def transformer_nncf_graph(self) -> NNCFGraphToTest:
return NNCFGraphTransformer(
matmul_metatype=ONNXMatMulMetatype,
softmax_metatype=ONNXSoftmaxMetatype,
mul_metatype=ONNXMulLayerMetatype,
const_metatype=ONNXConstantMetatype,
transpose_metatype=ONNXTransposeMetatype,
matmul_layer_weighted_attrs=ONNXLayerAttributes({"name": "edge_name", "shape": (1, 1, 1, 1)}),
matmul_layer_non_weighted_attrs=ONNXLayerAttributes(),
default_layer_attrs=ONNXLayerAttributes(),
)
Loading

0 comments on commit da5217a

Please sign in to comment.