diff --git a/docs/usage/post_training_compression/weights_compression/Usage.md b/docs/usage/post_training_compression/weights_compression/Usage.md
index 223805e5979..3a4cb0eef11 100644
--- a/docs/usage/post_training_compression/weights_compression/Usage.md
+++ b/docs/usage/post_training_compression/weights_compression/Usage.md
@@ -11,8 +11,8 @@ The Weights Compression algorithm is aimed at compressing the weights of the mod

 By default, weights are compressed asymmetrically to 8-bit integer data type - "INT8_ASYM" mode.
 OpenVINO backend also supports 4 modes of mixed precision weight quantization with a 4-bit data type as a primary precision - INT4_SYM, INT4_ASYM, NF4, E2M1. The primary precision in case of INT4_SYM mode is signed 4-bit integer and weights are quantized to it [symmetrically](/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#symmetric-quantization) without zero point. In case of INT4_ASYM mode - unsigned 4-bit integer and weight are quantized to it [asymmetrically](/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#asymmetric-quantization) with a typical non-fixed zero point. In case of NF4 mode - [nf4](https://arxiv.org/pdf/2305.14314v1.pdf) data type without zero point. In case of E2M1 mode - [e2m1](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data type without zero point and has 8bit [E8M0](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) scale. All 4-bit modes have a grouped quantization support, when small group of weights (e.g. 128) in the channel dimension share quantization parameters (scale).
-All embeddings, convolutions and last linear layers are always compressed to 8-bit integer data type. To quantize embeddings and last linear layers to 4-bit, use `all_layers=True`.
-Percent of the rest layers compressed to 4-bit can be configured by "ratio" parameter. E.g. ratio=0.9 means 90% of layers compressed to the corresponding 4-bit data type and the rest to 8-bit asymmetric integer data type.
+All embeddings, convolutions and last linear layers are always compressed to the backup mode, which is "INT8_ASYM" by default. To quantize embeddings and last linear layers to 4-bit, use `all_layers=True`.
+The percentage of the remaining layers compressed to 4-bit can be configured by the "ratio" parameter. E.g. ratio=0.9 means 90% of layers are compressed to the corresponding 4-bit data type and the rest to the backup mode. OpenVINO backend supports 3 backup modes: INT8_SYM, INT8_ASYM, and NONE, where NONE retains the original floating-point precision of the model weights. Backup mode is supported only for mixed-precision weight quantization.

 ### User guide

@@ -37,6 +37,13 @@ from nncf import compress_weights, CompressWeightsMode
 compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_SYM) # model is openvino.Model object
 ```

+- Compress weights to NF4 with group size = 128, except embeddings, convolutions and last linear layers - they remain in the original floating-point precision.
+
+```python
+from nncf import compress_weights, BackupMode, CompressWeightsMode
+compressed_model = compress_weights(model, mode=CompressWeightsMode.NF4, backup_mode=BackupMode.NONE) # model is openvino.Model object
+```
+
 - Generally, `INT4_SYM` mode is the fastest mixed-precision mode, but it may lead to a significant accuracy degradation or perplexity increase.
Compressing weights asymmetrically (`INT4_ASYM` mode) is the way to increase accuracy, however in turns it slows down inference a bit. If the accuracy or perplexity is still not satisfying, there are 2 more hyper-parameters to tune: `group_size` and `ratio`. Please refer to the [example](https://github.com/openvinotoolkit/nncf/blob/develop/examples/llm_compression/openvino/tiny_llama_find_hyperparams) how to automatically tune these parameters. diff --git a/nncf/__init__.py b/nncf/__init__.py index 64449545310..88ad2cfb09e 100644 --- a/nncf/__init__.py +++ b/nncf/__init__.py @@ -32,6 +32,7 @@ from nncf.errors import UnsupportedModelError as UnsupportedModelError from nncf.errors import UnsupportedVersionError as UnsupportedVersionError from nncf.errors import ValidationError as ValidationError +from nncf.parameters import BackupMode as BackupMode from nncf.parameters import CompressWeightsMode as CompressWeightsMode from nncf.parameters import DropType as DropType from nncf.parameters import ModelType as ModelType diff --git a/nncf/experimental/torch/fx/quantization/quantize_model.py b/nncf/experimental/torch/fx/quantization/quantize_model.py index 8061f2ab2f4..403e36f6e39 100644 --- a/nncf/experimental/torch/fx/quantization/quantize_model.py +++ b/nncf/experimental/torch/fx/quantization/quantize_model.py @@ -29,6 +29,7 @@ from nncf.experimental.torch.fx.transformations import apply_quantization_transformations from nncf.experimental.torch.fx.transformations import revert_quantization_transformations from nncf.experimental.torch.fx.transformations import shared_constants_unification_transformation +from nncf.parameters import BackupMode from nncf.parameters import CompressWeightsMode from nncf.parameters import ModelType from nncf.parameters import QuantizationMode @@ -124,6 +125,7 @@ def compress_weights_impl( scale_estimation: bool, gptq: bool, lora_correction: bool, + backup_mode: BackupMode, advanced_parameters: Optional[AdvancedCompressionParameters] = None, ) -> torch.fx.GraphModule: """ @@ -142,6 +144,7 @@ def compress_weights_impl( scale_estimation, gptq, lora_correction, + backup_mode, advanced_parameters, ) shared_constants_unification_transformation(model) diff --git a/nncf/openvino/quantization/quantize_model.py b/nncf/openvino/quantization/quantize_model.py index 2ca9d9bad40..cbf210ebc20 100644 --- a/nncf/openvino/quantization/quantize_model.py +++ b/nncf/openvino/quantization/quantize_model.py @@ -29,6 +29,7 @@ from nncf.openvino.quantization.backend_parameters import is_weight_compression_needed from nncf.openvino.quantization.quantize_ifmodel import apply_algorithm_if_bodies from nncf.openvino.rt_info import dump_parameters +from nncf.parameters import BackupMode from nncf.parameters import CompressWeightsMode from nncf.parameters import DropType from nncf.parameters import ModelType @@ -379,6 +380,7 @@ def compress_weights_impl( scale_estimation: bool, gptq: bool, lora_correction: bool, + backup_mode: BackupMode, advanced_parameters: Optional[AdvancedCompressionParameters] = None, ) -> ov.Model: """ @@ -398,6 +400,7 @@ def compress_weights_impl( scale_estimation, gptq, lora_correction, + backup_mode, advanced_parameters, ) graph = NNCFGraphFactory.create(model) diff --git a/nncf/parameters.py b/nncf/parameters.py index d4d43ea7d87..b7d27ef8018 100644 --- a/nncf/parameters.py +++ b/nncf/parameters.py @@ -96,6 +96,23 @@ class CompressWeightsMode(StrEnum): E2M1 = "e2m1" +@api(canonical_alias="nncf.BackupMode") +class BackupMode(StrEnum): + """ + Defines a backup mode for 
weight compression. + :param NONE: Stands for original floating-point precision of the model weights. + In this mode, weights are retained in their original precision without any quantization. + :param INT8_SYM: Stands for 8-bit integer symmetric quantization without zero point. + https://github.com/openvinotoolkit/nncf/blob/develop/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#symmetric-quantization + :param INT8_ASYM: Stands for 8-bit integer asymmetric quantization with a typical non-fixed zero point. + https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#asymmetric-quantization + """ + + NONE = "none" + INT8_SYM = "int8_sym" + INT8_ASYM = "int8_asym" + + @api(canonical_alias="nncf.SensitivityMetric") class SensitivityMetric(StrEnum): """ diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 24fc509c85e..0190de36471 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -27,6 +27,7 @@ from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend from nncf.common.utils.helpers import create_table +from nncf.parameters import BackupMode from nncf.parameters import CompressWeightsMode from nncf.parameters import SensitivityMetric from nncf.quantization.advanced_parameters import AdvancedCompressionParameters @@ -42,7 +43,6 @@ from nncf.scopes import IgnoredScope from nncf.scopes import get_ignored_node_names_from_ignored_scope from nncf.tensor import Tensor -from nncf.tensor.definitions import TensorDataType TModel = TypeVar("TModel") TTensor = TypeVar("TTensor") @@ -69,6 +69,7 @@ def __init__( scale_estimation: bool, gptq: bool, lora_correction: bool, + backup_mode: BackupMode = BackupMode.INT8_ASYM, advanced_parameters: Optional[AdvancedCompressionParameters] = None, ): """ @@ -79,15 +80,15 @@ def __init__( with a typical non-fixed zero point. INT4_SYM stands for a mixed-precision weights quantization with 4-bit integer as a primary precision. Weights are quantized to a primary precision symmetrically without zero point. - All embeddings and the last layer are always compressed to a backup precision, which is INT8_ASYM, - by default. All others are quantized whether to 4-bit integer or to a backup precision depending on + All embeddings and the last layer are always compressed to a backup_mode, which is INT8_ASYM, + by default. All others are quantized whether to 4-bit integer or to a backup_mode depending on criteria and the given ratio. INT4_ASYM is the same as INT4_SYM mode, but weights are quantized to a primary precision asymmetrically with a typical non-fixed zero point. NF4 is the same as INT4_SYM mode, but primary precision is NF4 data type without zero point. E2M1 is the same as INT4_SYM mode, but primary precision is E2M1 data type without zero point. :param ratio: the ratio between primary and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 - and the rest to INT8_ASYM). + and the rest to backup_mode). :param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale). The value -1 means no grouping. :param ignored_scope: An ignored scope that defined the list of model control @@ -102,6 +103,11 @@ def __init__( :param scale_estimation: determines whether to use or not scale estimation for 4 bit layers. 
:param gptq: determines whether to use or not GPTQ algorithm. :param lora_correction: determines whether to use or not LoRA Correction algorithm. + :param backup_mode: Defines a backup mode for mixed-precision weight compression. + NONE stands for original floating-point precision of the model weights. + In this mode, weights are retained in their original precision without any quantization. + INT8_SYM stands for 8-bit integer symmetric quantization without zero point. + INT8_ASYM stands for 8-bit integer asymmetric quantization with a typical non-fixed zero point. :param advanced_parameters: advanced parameters for algorithms in compression pipeline. """ super().__init__() @@ -119,6 +125,7 @@ def __init__( self._scale_estimation = scale_estimation self._gptq = gptq self._lora_correction = lora_correction + self._backup_mode = backup_mode self._advanced_parameters = ( advanced_parameters if advanced_parameters is not None else AdvancedCompressionParameters() ) @@ -265,7 +272,10 @@ def _proportion_str(num_weights_list: List[int], total_num_weights: int, total_n return f"{percentage:.0f}% ({len(num_weights_list)} / {total_num_params})" def _get_bitwidth_distribution_str( - self, all_params: List[WeightCompressionParameters], ratio_defining_params: List[WeightCompressionParameters] + self, + all_params: List[WeightCompressionParameters], + ratio_defining_params: List[WeightCompressionParameters], + ignored_scope_weight_statistics: List[int], ) -> str: """ Generates a table that shows the ratio of weights quantized to different number of bits. @@ -273,27 +283,32 @@ def _get_bitwidth_distribution_str( :param all_params: Information about each weight node. :param ratio_defining_params: Information about weights that are used for calculating ratio between primary and backup precisions. + :param ignored_scope_weight_statistics: Information about weight nodes from IgnoredScope. :return: A string containing the table. 
""" - num_bits_vs_num_weights_map = {} + dtype_vs_num_weights_map = {} ratio_defining_weight_names = set(wp.weight_name for wp in ratio_defining_params) for data in all_params: - num_bits = data.compression_config.num_bits - n_total, n_ratio_defining = num_bits_vs_num_weights_map.get(num_bits, ([], [])) + dtype = data.compression_config.mode if data.compression_config is not None else "float" + n_total, n_ratio_defining = dtype_vs_num_weights_map.get(dtype, ([], [])) if data.weight_name in ratio_defining_weight_names: n_ratio_defining.append(data.num_weights) n_total.append(data.num_weights) - num_bits_vs_num_weights_map[num_bits] = (n_total, n_ratio_defining) + dtype_vs_num_weights_map[dtype] = (n_total, n_ratio_defining) + + if ignored_scope_weight_statistics: + n_total, n_ratio_defining = dtype_vs_num_weights_map.get("float", ([], [])) + dtype_vs_num_weights_map["float"] = (n_total + ignored_scope_weight_statistics, n_ratio_defining) num_ratio_defining_weights = sum(ws.num_weights for ws in ratio_defining_params) num_ratio_defining_params = len(ratio_defining_params) - num_total_weights = sum(ws.num_weights for ws in all_params) - num_params = len(all_params) - num_bits_vs_num_weights_map = OrderedDict(sorted(num_bits_vs_num_weights_map.items(), reverse=True)) + num_total_weights = sum(ws.num_weights for ws in all_params) + sum(ignored_scope_weight_statistics) + num_params = len(all_params) + len(ignored_scope_weight_statistics) + dtype_vs_num_weights_map = OrderedDict(sorted(dtype_vs_num_weights_map.items(), reverse=True)) # Table creation - header = ["Num bits (N)", "% all parameters (layers)", "% ratio-defining parameters (layers)"] + header = ["Weight compression mode", "% all parameters (layers)", "% ratio-defining parameters (layers)"] rows = [] - for bitwidth, (n_total, n_ratio_defining) in num_bits_vs_num_weights_map.items(): + for bitwidth, (n_total, n_ratio_defining) in dtype_vs_num_weights_map.items(): rows.append( [ bitwidth, @@ -306,6 +321,35 @@ def _get_bitwidth_distribution_str( pretty_string = f"Statistics of the bitwidth distribution:\n{table}" return pretty_string + def _get_ignored_scope_weight_statistics(self, model: TModel, graph: NNCFGraph) -> List[int]: + """ + Collect the weight statistics for nodes in the ignored scope. + + :param model: Model for statistics collection. + :param graph: Model graph. + :return: A list of weight sizes for the ignored nodes. 
+ """ + ignored_names = get_ignored_node_names_from_ignored_scope(self._ignored_scope, graph, strict=False) + weighted_metatypes = ( + self._backend_entity.matmul_metatypes + + self._backend_entity.embedding_metatypes + + self._backend_entity.convolution_metatypes + ) + ignored_scope_weight_statistics = [] + for node_name in ignored_names: + node = graph.get_node_by_name(node_name) + is_node_with_weights = self._backend_entity.is_node_with_weights(node, graph) + if not is_node_with_weights or node.metatype not in weighted_metatypes: + continue + for _, weight_port_id in self._backend_entity.get_weight_names_and_port_ids(node, graph): + weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph) + if weight_dtype.is_float(): + continue + weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph) + weight_size = reduce(operator.mul, weight_shape, 1) + ignored_scope_weight_statistics.append(weight_size) + return ignored_scope_weight_statistics + def apply( self, model: TModel, @@ -315,7 +359,6 @@ def apply( ) -> TModel: self._set_backend_entity(model) nodes_to_compress = self._get_nodes_to_compress(graph) - activations = {} if dataset is not None and self._sensitivity_metric != SensitivityMetric.WEIGHT_QUANTIZATION_ERROR: activations = self._get_activations(dataset, self._subset_size, nodes_to_compress, graph, model) @@ -332,12 +375,7 @@ def apply( continue weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph) - if weight_dtype not in [ - TensorDataType.float16, - TensorDataType.bfloat16, - TensorDataType.float32, - TensorDataType.float64, - ]: + if not weight_dtype.is_float(): continue weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph) weight_size = reduce(operator.mul, weight_shape, 1) @@ -350,24 +388,57 @@ def apply( and len(reduction_axes) != 1 ): # NNCF supports multiple reduction axes only for ops with group_size != -1. - # Convolution ops are always quantized to 8-bits (without groups). + # Convolution ops are always kept in backup mode. # Embedding layers are quantized to 4-bits only if all_layers=True. # MatMul ops can't have multiple reduction axes. nncf_logger.warning( f"Weight compression expects a single reduction axis, but {len(reduction_axes)} given. " f"Weight shape: {weight_shape}, reduction axes: {reduction_axes}, " - f"node name: {node.node_name}. The node will be asymmetrically quantized to 8 bits." + f"node name: {node.node_name}. The node will be in {self._backup_mode} mode." 
) + if self._backup_mode == BackupMode.NONE: + wc_config = None + else: + mode = ( + CompressWeightsMode.INT8_ASYM + if self._backup_mode == BackupMode.INT8_ASYM + else CompressWeightsMode.INT8_SYM + ) + wc_config = WeightCompressionConfig(mode=mode) weight_params = WeightCompressionParameters( - weight_name, node, weight_port_id, weight_size, reduction_axes + weight_name, node, weight_port_id, weight_size, reduction_axes, wc_config ) all_weight_params.append(weight_params) weight_names.add(weight_name) ratio_defining_params = self._get_ratio_defining_params(all_weight_params, is_last_layer_shared) self._set_weight_compression_config(ratio_defining_params, model, graph, activations) - nncf_logger.info(self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params)) + + ignored_scope_weight_statistics = self._get_ignored_scope_weight_statistics(model, graph) + nncf_logger.info( + self._get_bitwidth_distribution_str( + all_weight_params, ratio_defining_params, ignored_scope_weight_statistics + ) + ) + + if self._backup_mode == BackupMode.NONE: + # Filter all_weight_params and nodes_to_compress by excluding nodes + # that should remain in their original floating-point precision + nodes_names_to_exclude = { + w_params.node_with_weight.node_name + for w_params in all_weight_params + if w_params.compression_config is None + } + all_weight_params = list( + filter( + lambda w_params: w_params.node_with_weight.node_name not in nodes_names_to_exclude, + all_weight_params, + ) + ) + nodes_to_compress = list( + filter(lambda node: node.node_name not in nodes_names_to_exclude, nodes_to_compress) + ) if self._awq and activations is not None and self._mode != CompressWeightsMode.E2M1: awq_params = self._advanced_parameters.awq_params @@ -446,6 +517,7 @@ def apply( "scale_estimation": self._scale_estimation, "gptq": self._gptq, "lora_correction": self._lora_correction, + "backup_mode": self._backup_mode.value, "advanced_parameters": convert_to_dict_recursively(self._advanced_parameters), }, algo_name="weight_compression", diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py index 6271bd6c255..aec029c55e7 100644 --- a/nncf/quantization/algorithms/weight_compression/awq.py +++ b/nncf/quantization/algorithms/weight_compression/awq.py @@ -161,6 +161,10 @@ def apply( for weight_op_friendly_name, _ in self._backend_entity.get_weight_names_and_port_ids(nncf_node, graph): target_node_names.append(weight_op_friendly_name) + # skip node if it is in IgnoredScope or should not be compressed + if target_node_names[-1] not in name_mapping: + continue + weight_params = self._all_weight_params[name_mapping[target_node_names[-1]]] if weight_params.compression_config.num_bits != 4: diff --git a/nncf/quantization/algorithms/weight_compression/config.py b/nncf/quantization/algorithms/weight_compression/config.py index 3e1a460b012..56dbc24f2e2 100644 --- a/nncf/quantization/algorithms/weight_compression/config.py +++ b/nncf/quantization/algorithms/weight_compression/config.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass +from dataclasses import field from typing import Optional, Tuple, TypeVar import numpy as np @@ -64,7 +65,7 @@ class WeightCompressionParameters: weight_port_id: int num_weights: np.uint64 reduction_axes: Tuple[int, ...] 
- compression_config = WeightCompressionConfig() + compression_config: Optional[WeightCompressionConfig] = field(default_factory=WeightCompressionConfig) def __post_init__(self): # Explicitly cast num_weights to avoid overflow on finding total number of weights. diff --git a/nncf/quantization/algorithms/weight_compression/mixed_precision.py b/nncf/quantization/algorithms/weight_compression/mixed_precision.py index 63f90b2677c..fd746fd1be7 100644 --- a/nncf/quantization/algorithms/weight_compression/mixed_precision.py +++ b/nncf/quantization/algorithms/weight_compression/mixed_precision.py @@ -103,7 +103,7 @@ def _calc_weight_sensitivity(self, weight_param: WeightCompressionParameters) -> weight = self._backend_entity.get_weight( weight_param.node_with_weight, weight_param.weight_port_id, self._model, self._graph ) - backup_config = weight_param.compression_config + backup_config = WeightCompressionConfig() reduction_axes = weight_param.reduction_axes int_error = get_integer_quantization_error(weight, reduction_axes, backup_config) eps = fns.finfo(weight).eps @@ -169,7 +169,7 @@ def _calc_weight_sensitivity(self, weight_param: WeightCompressionParameters) -> weight = self._backend_entity.get_weight( weight_param.node_with_weight, weight_param.weight_port_id, self._model, self._graph ) - backup_config = weight_param.compression_config + backup_config = WeightCompressionConfig() reduction_axes = weight_param.reduction_axes orig_shape = weight.shape diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 1f633458e98..90d69b73747 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -22,6 +22,7 @@ from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend from nncf.data import Dataset +from nncf.parameters import BackupMode from nncf.parameters import CompressWeightsMode from nncf.parameters import DropType from nncf.parameters import ModelType @@ -394,6 +395,7 @@ def compress_weights( scale_estimation: Optional[bool] = None, gptq: Optional[bool] = None, lora_correction: Optional[bool] = None, + backup_mode: Optional[BackupMode] = None, advanced_parameters: Optional[AdvancedCompressionParameters] = None, ) -> TModel: """ @@ -444,6 +446,12 @@ def compress_weights( :type gptq: bool :param lora_correction: Indicates whether to use Lora Correction algorithm. :type lora_correction: bool + :param backup_mode: Defines a backup mode for mixed-precision weight compression. + NONE stands for original floating-point precision of the model weights. + In this mode, weights are retained in their original precision without any quantization. + INT8_SYM stands for 8-bit integer symmetric quantization without zero point. + INT8_ASYM stands for 8-bit integer asymmetric quantization with a typical non-fixed zero point. + :type backup_mode: nncf.BackupMode :param advanced_parameters: Advanced parameters for compression algorithms. :type advanced_parameters: nncf.AdvancedCompressionParameters :return: The non-trainable model with compressed weights. @@ -474,6 +482,9 @@ def compress_weights( "Set them to None." ) + if backup_mode is not None: + raise AttributeError("Torch backend does not support backup_mode option.") + if is_wrapped_model(model): if not model.nncf.trace_parameters: raise ValueError( @@ -501,6 +512,9 @@ def compress_weights( f"but given {mode.value} mode." 
) + if backup_mode is not None: + raise AttributeError("TorchFX backend does not support backup_mode option.") + if any((awq, scale_estimation, gptq, lora_correction)): raise AttributeError( "TorchFX backend does not support 'awq', 'scale_estimation', 'gptq'," @@ -537,9 +551,13 @@ def compress_weights( group_size = -1 if ratio != 1 or group_size != -1: raise AttributeError( - "INT8 mode assumes per-channel quantization of all layers in 8 bit. " + "INT8 modes assume per-channel quantization of all layers in 8 bit. " "Default values of `ratio` (1) and `group_size` (-1) parameters can not be overridden" ) + + if backup_mode is not None: + raise AttributeError("INT8 modes do not support the `backup_mode` option") + options = { "all_layers": all_layers, "sensitivity_metric": sensitivity_metric, @@ -577,6 +595,8 @@ def compress_weights( if dataset is None else SensitivityMetric.MAX_ACTIVATION_VARIANCE ) + if backup_mode is None: + backup_mode = BackupMode.INT8_ASYM if ratio != 1 and dataset is None and sensitivity_metric != SensitivityMetric.WEIGHT_QUANTIZATION_ERROR: raise AttributeError( f"Mixed precision selection based on the given sensitivity metric={sensitivity_metric.value} requires " @@ -604,6 +624,7 @@ def compress_weights( scale_estimation, gptq, lora_correction, + backup_mode, advanced_parameters, ) diff --git a/nncf/tensor/definitions.py b/nncf/tensor/definitions.py index 447a6dd8bb5..5d2df4ac035 100644 --- a/nncf/tensor/definitions.py +++ b/nncf/tensor/definitions.py @@ -37,6 +37,12 @@ class TensorDataType(Enum): int64 = auto() uint8 = auto() + def is_float(self): + """ + :return: True if the tensor data type is a floating-point type, else False. + """ + return self in [TensorDataType.float16, TensorDataType.bfloat16, TensorDataType.float32, TensorDataType.float64] + class TensorDeviceType(Enum): """ diff --git a/nncf/torch/quantization/quantize_model.py b/nncf/torch/quantization/quantize_model.py index d15b2890efb..23cb451f5fe 100644 --- a/nncf/torch/quantization/quantize_model.py +++ b/nncf/torch/quantization/quantize_model.py @@ -18,6 +18,7 @@ from nncf.common.factory import NNCFGraphFactory from nncf.common.quantization.structs import QuantizationPreset from nncf.data import Dataset +from nncf.parameters import BackupMode from nncf.parameters import CompressWeightsMode from nncf.parameters import ModelType from nncf.parameters import QuantizationMode @@ -94,6 +95,7 @@ def compress_weights_impl( scale_estimation: bool, gptq: bool, lora_correction: bool, + backup_mode: BackupMode, advanced_parameters: Optional[AdvancedCompressionParameters] = None, ) -> torch.nn.Module: """ @@ -112,6 +114,7 @@ def compress_weights_impl( scale_estimation, gptq, lora_correction, + backup_mode, advanced_parameters, ) graph = NNCFGraphFactory.create(model) diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index 4ec30282bbd..75b7cf9c7fc 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -27,6 +27,7 @@ from nncf.errors import ValidationError from nncf.experimental.common.tensor_statistics.collectors import AggregatorBase from nncf.openvino.graph.node_utils import get_const_value +from nncf.parameters import BackupMode from nncf.quantization import compress_weights from nncf.quantization.advanced_parameters import AdvancedCompressionParameters as CompressionParams from nncf.quantization.advanced_parameters import 
AdvancedLoraCorrectionParameters as LoraParams @@ -704,6 +705,9 @@ def test_raise_error_channel_size_is_not_divisible_by_group_size(): {"lora_correction": True}, {"gptq": True}, {"awq": True}, + {"backup_mode": BackupMode.NONE}, + {"backup_mode": BackupMode.INT8_ASYM}, + {"backup_mode": BackupMode.INT8_SYM}, ), ) def test_raise_error_with_unsupported_params_for_int8(mode, params): @@ -1049,6 +1053,30 @@ def test_one_dimentional_samples(mode): assert op.get_shape() == [sz, 1] +def test_awq_with_ignored_scope(): + model = AWQMatmulModel().ov_model + sz = 8 + n_samples = 10 + dataset = Dataset([np.ones([i + 1, sz]) for i in range(n_samples)]) + + compressed_model = compress_weights( + model, + mode=CompressWeightsMode.INT4_ASYM, + ratio=1.0, + group_size=-1, + dataset=dataset, + awq=True, + ignored_scope=IgnoredScope(names=["MatMul_6"]), + ) + + act_num = 0 + num_compressed = 8 + for op in compressed_model.get_ops(): + if op.get_type_name() == "Constant" and op.get_element_type() == ov.Type.u4: + act_num += 1 + assert act_num == num_compressed + + def get_shape_for_second_input(op_with_weights: ov.Node) -> List[int]: return list(op_with_weights.inputs()[1].get_shape()) @@ -1245,3 +1273,83 @@ def test_lora_with_mixed_precision(): op_name = op.get_friendly_name() if op.get_type_name() == "Constant" and ("/zero_point" in op_name or "/scale" in op_name): assert op.get_shape() == [sz, 1] + + +@pytest.mark.parametrize("backup_mode", [BackupMode.NONE, BackupMode.INT8_ASYM, BackupMode.INT8_SYM]) +def test_data_free_compression_with_backup_mode(backup_mode): + model = AWQMatmulModel().ov_model + compressed_model = compress_weights( + model, + mode=CompressWeightsMode.NF4, + ratio=0.7, + group_size=-1, + backup_mode=backup_mode, + ) + act_num = 0 + num_compressed = 3 + if backup_mode == BackupMode.INT8_ASYM: + backup_ov_mode = ov.Type.u8 + elif backup_mode == BackupMode.INT8_SYM: + backup_ov_mode = ov.Type.i8 + else: + backup_ov_mode = ov.Type.f32 + for op in compressed_model.get_ops(): + if op.get_type_name() == "Constant": + if op.get_element_type() == ov.Type.nf4: + act_num += 1 + elif "/scale" in op.get_friendly_name(): + assert op.get_element_type() == ov.Type.f16 + else: + assert op.get_element_type() == backup_ov_mode + assert act_num == num_compressed + + +@pytest.mark.parametrize("backup_mode", [BackupMode.NONE, BackupMode.INT8_ASYM, BackupMode.INT8_SYM]) +@pytest.mark.parametrize( + ("params", "num_compressed"), + ( + ({"all_layers": True}, 8), + ({"all_layers": False}, 6), + ({"sensitivity_metric": SensitivityMetric.HESSIAN_INPUT_ACTIVATION}, 6), + ({"sensitivity_metric": SensitivityMetric.MEAN_ACTIVATION_VARIANCE}, 6), + ({"sensitivity_metric": SensitivityMetric.MAX_ACTIVATION_VARIANCE}, 6), + ({"sensitivity_metric": SensitivityMetric.MEAN_ACTIVATION_MAGNITUDE}, 6), + ({"scale_estimation": True}, 6), + ({"lora_correction": True}, 6), + ({"gptq": True}, 6), + ({"awq": True}, 6), + ), +) +def test_data_based_compression_with_backup_mode(backup_mode, params, num_compressed): + model = AWQMatmulModel().ov_model + sz = 8 + n_samples = 10 + dataset = Dataset([np.ones([i + 1, sz]) for i in range(n_samples)]) + + compressed_model = compress_weights( + model, + mode=CompressWeightsMode.INT4_ASYM, + ratio=0.8, + group_size=-1, + dataset=dataset, + backup_mode=backup_mode, + **params, + ) + act_num = 0 + if backup_mode == BackupMode.INT8_ASYM: + backup_ov_mode = ov.Type.u8 + elif backup_mode == BackupMode.INT8_SYM: + backup_ov_mode = ov.Type.i8 + else: + backup_ov_mode = ov.Type.f32 + for op in 
compressed_model.get_ops(): + if op.get_type_name() == "Constant": + if op.get_element_type() == ov.Type.u4: + act_num += 1 + elif "/scale" in op.get_friendly_name(): + assert op.get_element_type() == ov.Type.f16 + elif "_lora_" in op.get_friendly_name(): + assert op.get_element_type() == ov.Type.u8 + else: + assert op.get_element_type() == backup_ov_mode + assert act_num == num_compressed diff --git a/tests/post_training/data/wc_reference_data.yaml b/tests/post_training/data/wc_reference_data.yaml index df3f7d76838..91e7e4be220 100644 --- a/tests/post_training/data/wc_reference_data.yaml +++ b/tests/post_training/data/wc_reference_data.yaml @@ -40,3 +40,7 @@ tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV: num_int4: 11 num_int8: 290 metrics_xfail_reason: "Issue-148819" +tinyllama_awq_backup_mode_none_backend_OV: + metric_value: 0.84793 + num_int4: 208 + num_int8: 0 \ No newline at end of file diff --git a/tests/post_training/model_scope.py b/tests/post_training/model_scope.py index 09b305e16c1..5f5080fee77 100644 --- a/tests/post_training/model_scope.py +++ b/tests/post_training/model_scope.py @@ -15,6 +15,7 @@ import nncf from nncf import ModelType from nncf import QuantizationPreset +from nncf.parameters import BackupMode from nncf.parameters import CompressWeightsMode from nncf.parameters import SensitivityMetric from nncf.quantization.advanced_parameters import AdvancedCompressionParameters @@ -498,6 +499,21 @@ "params": {"is_stateful": True}, "backends": [BackendType.OV], }, + { + "reported_name": "tinyllama_awq_backup_mode_none", + "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b", + "pipeline_cls": LMWeightCompression, + "compression_params": { + "group_size": 64, + "ratio": 0.8, + "all_layers": True, + "backup_mode": BackupMode.NONE, + "mode": CompressWeightsMode.INT4_ASYM, + "awq": True, + "ignored_scope": nncf.IgnoredScope(types=["Gather"]), + }, + "backends": [BackendType.OV], + }, ] diff --git a/tests/torch/fx/test_compress_weights.py b/tests/torch/fx/test_compress_weights.py index 1d5012d5d57..20793e31493 100644 --- a/tests/torch/fx/test_compress_weights.py +++ b/tests/torch/fx/test_compress_weights.py @@ -15,6 +15,7 @@ import torch from torch._export import capture_pre_autograd_graph +from nncf import BackupMode from nncf import CompressWeightsMode from nncf.common.factory import NNCFGraphFactory from nncf.data.dataset import Dataset @@ -208,6 +209,9 @@ def test_compress_weights_functional_model(mode): {"scale_estimation": True}, {"lora_correction": True}, {"dataset": Dataset([1])}, + {"backup_mode": BackupMode.NONE}, + {"backup_mode": BackupMode.INT8_ASYM}, + {"backup_mode": BackupMode.INT8_SYM}, ), ) def test_raise_error_with_unsupported_params_for_int8(mode, params): diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py index 5e4ca75e128..dee60e92e5f 100644 --- a/tests/torch/ptq/test_weights_compression.py +++ b/tests/torch/ptq/test_weights_compression.py @@ -13,6 +13,7 @@ import torch import torch.nn.functional as F +from nncf import BackupMode from nncf import CompressWeightsMode from nncf import SensitivityMetric from nncf.quantization import compress_weights @@ -214,6 +215,9 @@ def forward(self, input): {"awq": True}, {"scale_estimation": True}, {"lora_correction": True}, + {"backup_mode": BackupMode.NONE}, + {"backup_mode": BackupMode.INT8_ASYM}, + {"backup_mode": BackupMode.INT8_SYM}, ), ) def test_raise_error_with_unsupported_params_for_int8(mode, params):
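
For reference, the snippet below is a minimal, self-contained sketch (not part of the diff) of how the new `backup_mode` option is exercised through the public API, mirroring the OpenVINO tests above. The toy three-MatMul model builder is purely illustrative; any `openvino.Model` with floating-point weights would do, and the exact number of NF4 constants depends on the model and the `ratio` value.

```python
import numpy as np
import openvino as ov
from openvino.runtime import opset13 as opset

from nncf import BackupMode, CompressWeightsMode, compress_weights


def build_toy_matmul_model() -> ov.Model:
    # Illustrative stand-in for a real model: a chain of three MatMuls
    # with floating-point weight constants (not an NNCF test model).
    inp = opset.parameter([1, 8], dtype=np.float32, name="input")
    rng = np.random.default_rng(0)
    out = inp
    for _ in range(3):
        weight = opset.constant(rng.random((8, 8), dtype=np.float32))
        out = opset.matmul(out, weight, transpose_a=False, transpose_b=False)
    return ov.Model([out], [inp])


model = build_toy_matmul_model()

# NF4 primary precision; layers that are not selected for 4-bit compression
# (the last linear layer and the share of layers outside `ratio`) keep their
# original floating-point precision because backup_mode=NONE.
compressed = compress_weights(
    model,
    mode=CompressWeightsMode.NF4,
    ratio=0.7,
    group_size=-1,
    backup_mode=BackupMode.NONE,
)

# Count the weight constants that were actually converted to NF4.
num_nf4 = sum(
    1
    for op in compressed.get_ops()
    if op.get_type_name() == "Constant" and op.get_element_type() == ov.Type.nf4
)
print(f"NF4-compressed weight constants: {num_nf4}")
```

With `BackupMode.INT8_ASYM` or `BackupMode.INT8_SYM`, the same loop would count `ov.Type.u8` or `ov.Type.i8` weight constants instead, as the parametrized tests above do.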