diff --git a/docs/compression_algorithms/CompressWeights.md b/docs/compression_algorithms/CompressWeights.md
index 58b76a4d64f..15bd2f2059f 100644
--- a/docs/compression_algorithms/CompressWeights.md
+++ b/docs/compression_algorithms/CompressWeights.md
@@ -8,22 +8,30 @@ The Weights Compression algorithm is aimed at compressing the weights of the mod
#### Supported modes
-By default, weights are compressed to 8-bit integer data type - "INT8" mode.
+By default, weights are compressed asymmetrically to the 8-bit integer data type - "INT8_ASYM" mode.
OpenVINO backend also supports 3 modes of mixed precision weight quantization with a 4-bit data type as a primary precision - INT4_SYM, INT4_ASYM and NF4. The primary precision in case of INT4_SYM mode is unsigned 4-bit integer and weights are quantized to it [symmetrically](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#symmetric-quantization) with a fixed zero point equals to 8. In case of INT4_ASYM mode - also unsigned 4-bit integer, but weight are quantized to it [asymmetrically](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#asymmetric-quantization) with a typical non-fixed zero point. In case of NF4 mode - [nf4](https://arxiv.org/pdf/2305.14314v1.pdf) data type without zero point.
All 4-bit modes have a grouped quantization support, when small group of weights (e.g. 128) in the channel dimension share quantization parameters (scale).
All embeddings and last linear layers are always compressed to 8-bit integer data type.
-Percent of the rest layers compressed to 4-bit can be configured by "ratio" parameter. E.g. ratio=0.9 means 90% of layers compressed to the corresponding 4-bit data type and the rest to 8-bit integer data type.
+The percentage of the remaining layers compressed to 4-bit can be configured by the "ratio" parameter. E.g. ratio=0.9 means that 90% of the layers are compressed to the corresponding 4-bit data type and the rest to the 8-bit asymmetric integer data type.
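+
+For illustration only, below is a minimal NumPy sketch of roughly what these schemes do. The scale formulas are deliberately simplified and the array names are made up; this is not the NNCF implementation.
+
+```python
+import numpy as np
+
+# Illustrative sketch, not the NNCF implementation: simplified INT8_ASYM and grouped INT4_SYM.
+weight = np.random.randn(512, 512).astype(np.float32)  # [output_channels, input_channels]
+
+# INT8_ASYM: per-output-channel scale and zero point derived from min/max, values in [0, 255].
+w_min = weight.min(axis=-1, keepdims=True)
+w_max = weight.max(axis=-1, keepdims=True)
+scale = (w_max - w_min) / 255
+zero_point = np.round(-w_min / scale)
+int8_asym = np.clip(np.round(weight / scale) + zero_point, 0, 255).astype(np.uint8)
+
+# INT4_SYM with group_size=128: each group of 128 weights along the channel dimension shares
+# one scale; the zero point is fixed to 8, values in [0, 15] (stored here in uint8 for simplicity).
+group_size = 128
+grouped = weight.reshape(weight.shape[0], -1, group_size)
+group_scale = np.abs(grouped).max(axis=-1, keepdims=True) / 7
+int4_sym = np.clip(np.round(grouped / group_scale) + 8, 0, 15).astype(np.uint8)
+```
+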
#### User guide
-- Compress weights to 8-bit integer data type.
+- Compress weights asymmetrically to 8-bit integer data type.
```python
from nncf import compress_weights
compressed_model = compress_weights(model)
```
-- Compress weights symmetrically to 4-bit integer data type with group size = 128, except embeddings and last linear layers - they are compressed to 8-bit integer data type.
+- Compress weights symmetrically to 8-bit integer data type.
+
+```python
+from nncf import compress_weights
+from nncf import CompressWeightsMode
+compressed_model = compress_weights(model, mode=CompressWeightsMode.INT8_SYM)
+```
+
+- Compress weights symmetrically to 4-bit integer data type with group size = 128, except embeddings and the last linear layers, which are compressed asymmetrically to 8-bit integer data type.
```python
from nncf import compress_weights
@@ -36,7 +44,7 @@ compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_SYM)
If the accuracy or perplexity is still not satisfying, there are 2 more hyper-parameters to tune: `group_size` and `ratio`.
Lower group size and less ratio of 4-bit layers usually improve accuracy at the sacrifice of inference speed.
Below is the example how to compress weights of 90% of layers to 4-bit integer asymmetrically with the group size 64, and
- the rest of layers to 8-bit integer data type. The same parametrization is applicable for `INT4_SYM` mode.
+  the rest of the layers to the 8-bit asymmetric integer data type. The same parametrization is applicable for `INT4_SYM` mode.
```python
from nncf import compress_weights
@@ -45,7 +53,7 @@ compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_ASYM, g
```
- `NF4` mode can be considered for improving accuracy, but currently models quantized to nf4 should not be faster models
- quantized to 8-bit integer. Here's the example how to compress weights to nf4 data type with group size = 128.
+  quantized to 8-bit asymmetric integer. Here is an example of how to compress weights to the nf4 data type with group size = 128.
Different `group_size` and `ratio` are also supported.
```python
@@ -79,7 +87,7 @@ Here is the perplexity and model size before and after weight compression for di
databricks/dolly-v2-3b |
- int8 |
+ int8_asym |
5.07 |
0.05 |
2.6 |
@@ -107,7 +115,7 @@ Here is the perplexity and model size before and after weight compression for di
facebook/opt-6.7b |
- int8 |
+ int8_asym |
4.27 |
0.01 |
6.2 |
@@ -135,7 +143,7 @@ Here is the perplexity and model size before and after weight compression for di
meta-llama/Llama-2-7b-chat-hf |
- int8 |
+ int8_asym |
3.29 |
0.01 |
6.3 |
@@ -163,7 +171,7 @@ Here is the perplexity and model size before and after weight compression for di
togethercomputer/RedPajama-INCITE-7B-Instruct |
- int8 |
+ int8_asym |
4.17 |
0.02 |
6.4 |
@@ -191,7 +199,7 @@ Here is the perplexity and model size before and after weight compression for di
meta-llama/Llama-2-13b-chat-hf |
- int8 |
+ int8_asym |
2.91 |
0 |
12.1 |
@@ -218,7 +226,7 @@ Here is the perplexity and model size before and after weight compression for di
- The algorithm is supported for OpenVINO and PyTorch models.
- The compression applies in-place.
- The compressed model is not trainable.
-- INT4_SYM, INT4_ASYM and NF4 modes, grouped quantization and mixed precision selection is available for OpenVINO backend only.
+- INT8_SYM, INT4_SYM, INT4_ASYM and NF4 modes, grouped quantization and mixed-precision selection are available for the OpenVINO backend only.
- NF4 support is experimental - models quantized to nf4 should not be faster models quantized to 8-bit integer.
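+
+Note that the former `INT8` mode is kept only as a deprecated alias of `INT8_ASYM`: passing it still works, but a deprecation warning is emitted and the mode is remapped to `INT8_ASYM`. Existing code can migrate as follows:
+
+```python
+from nncf import compress_weights
+from nncf import CompressWeightsMode
+
+# Deprecated: emits a deprecation warning and is remapped to INT8_ASYM internally.
+compressed_model = compress_weights(model, mode=CompressWeightsMode.INT8)
+
+# Preferred equivalent.
+compressed_model = compress_weights(model, mode=CompressWeightsMode.INT8_ASYM)
+```
+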
#### Additional resources
diff --git a/nncf/parameters.py b/nncf/parameters.py
index adbcfb2a5dc..97ccea267be 100644
--- a/nncf/parameters.py
+++ b/nncf/parameters.py
@@ -62,10 +62,15 @@ class DropType(Enum):
class CompressWeightsMode(Enum):
"""
Defines a mode for weight compression.
- :param INT8: Stands for 8-bit integer quantization of all weights.
+ :param INT8_SYM: Stands for 8-bit integer symmetric quantization of all weights.
+        Weights are quantized symmetrically with a fixed zero point equal to 128.
+ https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#symmetric-quantization
+    :param INT8_ASYM: The same as INT8_SYM mode, but weights are quantized asymmetrically
+        with a typical non-fixed zero point.
+ https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#asymmetric-quantization
:param INT4_SYM: Stands for a mixed-precision weights quantization with 4-bit integer as a primary precision.
Weights are quantized to a primary precision symmetrically with a fixed zero point equals to 8.
- All embeddings and the last layer are always compressed to a backup precision, which is 8-bit integer,
+ All embeddings and the last layer are always compressed to a backup precision, which is INT8_ASYM,
by default. All others are quantized whether to 4-bit integer or to a backup precision depending on
criteria and the given ratio.
https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#symmetric-quantization
@@ -73,9 +78,12 @@ class CompressWeightsMode(Enum):
with a typical non-fixed zero point.
https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#asymmetric-quantization
:param NF4: The the same as INT4_SYM mode, but primary precision is NF4 data type without zero point.
+    :param INT8: The mode is deprecated and will be removed in future releases. Please use `INT8_ASYM` instead.
"""
- INT8 = "int8"
+ INT8_SYM = "int8_sym"
+ INT8_ASYM = "int8_asym"
INT4_SYM = "int4_sym"
INT4_ASYM = "int4_asym"
NF4 = "nf4"
+ INT8 = "int8" # Deprecated mode
diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py
index 867a253993d..b1596fb8028 100644
--- a/nncf/quantization/algorithms/weight_compression/algorithm.py
+++ b/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -54,17 +54,20 @@ def __init__(
):
"""
:param mode: Defines a mode for weight compression.
- INT8 stands for 8-bit integer quantization of all weights.
+ INT8_SYM stands for 8-bit integer symmetric quantization of all weights.
+            Weights are quantized symmetrically with a fixed zero point equal to 128.
+            INT8_ASYM is the same as INT8_SYM mode, but weights are quantized asymmetrically
+            with a typical non-fixed zero point.
INT4_SYM stands for a mixed-precision weights quantization with 4-bit integer as a primary precision.
Weights are quantized to a primary precision symmetrically with a fixed zero point equals to 8.
- All embeddings and the last layer are always compressed to a backup precision, which is 8-bit integer,
+ All embeddings and the last layer are always compressed to a backup precision, which is INT8_ASYM,
by default. All others are quantized whether to 4-bit integer or to a backup precision depending on
criteria and the given ratio.
INT4_ASYM is the same as INT4_SYM mode, but weights are quantized to a primary precision asymmetrically
with a typical non-fixed zero point.
NF4 is the same as INT4_SYM mode, but primary precision is NF4 data type without zero point.
:param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4
- and the rest to INT8).
+ and the rest to INT8_ASYM).
:param group_size: number of weights (e.g. 128) in the channel dimension
that share quantization parameters (scale). The value -1 means no grouping.
:param ignored_scope: An ignored scope that defined the list of model control
diff --git a/nncf/quantization/algorithms/weight_compression/backend.py b/nncf/quantization/algorithms/weight_compression/backend.py
index 8f00fdca516..4577fe8cb1c 100644
--- a/nncf/quantization/algorithms/weight_compression/backend.py
+++ b/nncf/quantization/algorithms/weight_compression/backend.py
@@ -47,10 +47,13 @@ def validate_params(mode: CompressWeightsMode, ignored_scope: Optional[IgnoredSc
parameters. Should be called on early algorithm steps to prevent execution of time-consuming operations.
:param mode: Defines a mode for weight compression.
- INT8 stands for 8-bit integer quantization of all weights.
+ INT8_SYM stands for 8-bit integer symmetric quantization of all weights.
+            Weights are quantized symmetrically with a fixed zero point equal to 128.
+            INT8_ASYM is the same as INT8_SYM mode, but weights are quantized asymmetrically
+            with a typical non-fixed zero point.
INT4_SYM stands for a mixed-precision weights quantization with 4-bit integer as a primary precision.
Weights are quantized to a primary precision symmetrically with a fixed zero point equals to 8.
- All embeddings and the last layer are always compressed to a backup precision, which is 8-bit integer,
+ All embeddings and the last layer are always compressed to a backup precision, which is INT8_ASYM,
by default. All others are quantized whether to 4-bit integer or to a backup precision depending on
criteria and the given ratio.
INT4_ASYM is the same as INT4_SYM mode, but weights are quantized to a primary precision asymmetrically
@@ -77,17 +80,20 @@ def do_compression(
:param nodes_to_compress: List of nodes in the model's graph,
corresponding to the layers for weight compression.
:param mode: Defines a mode for weight compression.
- INT8 stands for 8-bit integer quantization of all weights.
+ INT8_SYM stands for 8-bit integer symmetric quantization of all weights.
+            Weights are quantized symmetrically with a fixed zero point equal to 128.
+            INT8_ASYM is the same as INT8_SYM mode, but weights are quantized asymmetrically
+            with a typical non-fixed zero point.
INT4_SYM stands for a mixed-precision weights quantization with 4-bit integer as a primary precision.
Weights are quantized to a primary precision symmetrically with a fixed zero point equals to 8.
- All embeddings and the last layer are always compressed to a backup precision, which is 8-bit integer,
+ All embeddings and the last layer are always compressed to a backup precision, which is INT8_ASYM,
by default. All others are quantized whether to 4-bit integer or to a backup precision depending on
criteria and the given ratio.
INT4_ASYM is the same as INT4_SYM mode, but weights are quantized to a primary precision asymmetrically
with a typical non-fixed zero point.
NF4 is the same as INT4_SYM mode, but primary precision is NF4 data type without zero point.
:param ratio: The ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4
- and the rest to INT8).
+ and the rest to INT8_ASYM).
:param group_size: Number of weights (e.g. 128) in the channel dimension
that share quantization parameters (scale). The value -1 means no grouping.
:return: A resulting model with compressed weights.
diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
index 2738a7a38de..a361ba7501c 100644
--- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -102,13 +102,8 @@ def do_compression(
all_weight_params.append(weight_params)
quantized_nodes_ids.add(id(weight_node))
- internal_weight_params = all_weight_params
- if mode != CompressWeightsMode.INT8:
- internal_weight_params = list(filter(lambda wp: wp.metatype != OVEmbeddingMetatype, all_weight_params))
- if not is_last_layer_compressed:
- internal_weight_params = internal_weight_params[:-1]
- primary_config = WeightCompressionConfig(mode=mode, group_size=group_size)
- _assign_mixed_precision(internal_weight_params, ratio, primary_config)
+ internal_weight_params = _get_internal_weight_params(all_weight_params, mode, is_last_layer_compressed)
+ _set_weight_compression_config(internal_weight_params, mode, ratio, group_size)
nncf_logger.info(_get_bitwidth_distribution_str(all_weight_params, internal_weight_params))
for wp in track(all_weight_params, description="Applying Weight Compression"):
@@ -121,28 +116,25 @@ def do_compression(
weight = get_const_value(weight_node)
config = wp.compression_config
+ original_shape = weight.shape
if config.mode == CompressWeightsMode.NF4:
- original_shape = weight.shape
norm_weight, scale = _get_norm_weight_and_nf4_scale(weight, wp.reduction_axis, group_size)
compressed_const = opset.constant(norm_weight, dtype=ov.Type.nf4, name=weight_name)
convert = opset.convert(compressed_const, original_weight_dtype)
mul = opset.multiply(convert, scale.astype(original_weight_dtype), name=wp.fq_name)
- if config.group_size != -1:
- mul = opset.reshape(mul, output_shape=original_shape, special_zero=False)
- last_output = mul.output(0)
else:
- original_shape = weight.shape
compressed_weights, scale, zero_point = _do_integer_quantization(weight, wp.reduction_axis, config)
- compression_type = np.uint8 if config.num_bits == 8 else ov.Type.u4
+ compression_type = ov.Type.u8 if config.num_bits == 8 else ov.Type.u4
compressed_weights_node = opset.constant(compressed_weights, dtype=compression_type, name=weight_name)
convert_weights_node = opset.convert(compressed_weights_node, original_weight_dtype)
zero_point_node = opset.constant(zero_point, dtype=compression_type, name=f"{weight_name}/ZP")
convert_zp_node = opset.convert(zero_point_node, original_weight_dtype)
sub = opset.subtract(convert_weights_node, convert_zp_node)
mul = opset.multiply(sub, scale.astype(original_weight_dtype), name=wp.fq_name)
- if config.group_size != -1:
- mul = opset.reshape(mul, output_shape=original_shape, special_zero=False)
- last_output = mul.output(0)
+
+ if config.group_size != -1:
+ mul = opset.reshape(mul, output_shape=original_shape, special_zero=False)
+ last_output = mul.output(0)
for target_input in target_inputs:
target_input.replace_source_output(last_output)
@@ -167,12 +159,12 @@ class WeightCompressionConfig:
"""
Information on how to compress (quantize) a specific weight.
- :param mode: Defines a mode for weight compression. Defaults to INT8 mode.
+ :param mode: Defines a mode for weight compression. Defaults to INT8_ASYM mode.
:param group_size: Number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale).
The value -1 means no grouping. Defaults to -1.
"""
- mode: Optional[CompressWeightsMode] = CompressWeightsMode.INT8
+ mode: Optional[CompressWeightsMode] = CompressWeightsMode.INT8_ASYM
group_size: Optional[int] = -1
@property
@@ -180,7 +172,7 @@ def num_bits(self):
"""
:return: number of bits that is used for storing a single quantized value in the given mode.
"""
- return 8 if self.mode == CompressWeightsMode.INT8 else 4
+ return 8 if self.mode in [CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM] else 4
@dataclass
@@ -212,7 +204,10 @@ def _do_integer_quantization(
"""
The method quantizes the given weights to integer data type in accordance with the compression config.
The config defines a quantization mode:
- INT8 mode refers to unsigned int8 asymmetric weight compression - quantization to [0, 255] range.
+        INT8_SYM mode refers to unsigned int8 symmetric weight compression with a fixed zero point equal to 128 -
+ quantization to [0, 255] range.
+ INT8_ASYM mode refers to unsigned int8 asymmetric weight compression with a typical non-fixed zero-point -
+ quantization to [0, 255] range.
INT4_ASYM mode refers to unsigned int4 asymmetric weight compression with a typical non-fixed zero-point -
quantization to [0, 15] range.
INT4_SYM mode refers to unsigned int4 symmetric weight compression with a fixed zero point equals to 8 -
@@ -239,7 +234,7 @@ def _do_integer_quantization(
# weights are reshaped from [a1, r, a2] to [a1, r//gs, gs, a2]
weight, reduction_axis = _reshape_weights_for_grouped_quantization(weight, reduction_axis, group_size)
- if mode in [CompressWeightsMode.INT8, CompressWeightsMode.INT4_ASYM]:
+ if mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM]:
min_values = np.min(weight, axis=reduction_axis, keepdims=True) # [a1, r, a2] -> [a1, 1, a2]
max_values = np.max(weight, axis=reduction_axis, keepdims=True) # [a1, r, a2] -> [a1, 1, a2]
scale, zero_point = calculate_scale_zero_point(
@@ -349,21 +344,19 @@ def _get_bitwidth_distribution_str(all_params: List[WeightNodeParams], internal_
:param internal_params: List of information about weight nodes that are considered for mixed precision.
:return: A string containing the table.
"""
- not_internal_params = [wp for wp in all_params if wp not in internal_params]
num_bits_vs_num_weights_map = {}
- for data in internal_params:
- num_bits = data.compression_config.num_bits
- n_internal, n_internal = num_bits_vs_num_weights_map.get(num_bits, ([], []))
- n_internal.append(data.num_weights)
- num_bits_vs_num_weights_map[num_bits] = (n_internal, n_internal)
- for data in not_internal_params:
+ internal_fq_names = set(wp.fq_name for wp in internal_params)
+ for data in all_params:
num_bits = data.compression_config.num_bits
n_total, n_internal = num_bits_vs_num_weights_map.get(num_bits, ([], []))
+ if data.fq_name in internal_fq_names:
+ n_internal.append(data.num_weights)
n_total.append(data.num_weights)
num_bits_vs_num_weights_map[num_bits] = (n_total, n_internal)
+
num_internal_weights = sum(ws.num_weights for ws in internal_params)
num_internal_params = len(internal_params)
- total_num_weights = num_internal_weights + sum(ws.num_weights for ws in not_internal_params)
+ num_total_weights = sum(ws.num_weights for ws in all_params)
num_params = len(all_params)
num_bits_vs_num_weights_map = OrderedDict(sorted(num_bits_vs_num_weights_map.items(), reverse=True))
# Table creation
@@ -373,7 +366,7 @@ def _get_bitwidth_distribution_str(all_params: List[WeightNodeParams], internal_
rows.append(
[
bitwidth,
- _proportion_str(n_total, total_num_weights, num_params),
+ _proportion_str(n_total, num_total_weights, num_params),
_proportion_str(n_internal, num_internal_weights, num_internal_params),
]
)
@@ -383,6 +376,25 @@ def _get_bitwidth_distribution_str(all_params: List[WeightNodeParams], internal_
return pretty_string
+def _get_internal_weight_params(
+ all_weight_params: List[WeightNodeParams], mode: CompressWeightsMode, is_last_layer_compressed: bool
+) -> List[WeightNodeParams]:
+ """
+ Returns the internal weight parameters.
+
+ :param all_weight_params: List of all weight parameters.
+ :param mode: Weight compression mode.
+ :param is_last_layer_compressed: Indicates whether the last layer is compressed.
+ :return: List of information about weight nodes that are considered for mixed precision.
+ """
+ internal_weight_params = all_weight_params
+ if mode not in [CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM]:
+ internal_weight_params = list(filter(lambda wp: wp.metatype != OVEmbeddingMetatype, internal_weight_params))
+ if not is_last_layer_compressed:
+ internal_weight_params = internal_weight_params[:-1]
+ return internal_weight_params
+
+
def _assign_mixed_precision(
internal_weight_params: List[WeightNodeParams], ratio: float, primary_config: WeightCompressionConfig
) -> None:
@@ -391,14 +403,10 @@ def _assign_mixed_precision(
:param internal_weight_params: List of information about internal weight nodes. Only internal nodes are considered
for mixed precision. The quantization scheme is added to this info.
:param ratio: The ratio between primary and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4
- and the rest to INT8).
+ and the rest to INT8_ASYM).
:param primary_config: Information on how to compress (quantize) weights to primary precision.
:return: None.
"""
- if ratio == 1:
- for weight_param in internal_weight_params:
- weight_param.compression_config = primary_config
- return
errors = []
num_internal_weights = 0
for weight_param in track(internal_weight_params, description="Searching for Mixed-Precision Configuration"):
@@ -421,3 +429,23 @@ def _assign_mixed_precision(
break
weight_param.compression_config = primary_config
num_weights_in_4bit += weight_param.num_weights
+
+
+def _set_weight_compression_config(
+ internal_weight_params: List[WeightNodeParams], mode: CompressWeightsMode, ratio: float, group_size: int
+) -> None:
+ """
+    Sets the compression configuration for the given weight parameters based on the mode, ratio and group size.
+
+ :param internal_weight_params: List of information about internal weight nodes.
+ :param mode: Weight compression mode.
+ :param ratio: The ratio between primary and backup precisions.
+ :param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale).
+ :return: None.
+ """
+ primary_config = WeightCompressionConfig(mode=mode, group_size=group_size)
+ if ratio == 1:
+ for weight_param in internal_weight_params:
+ weight_param.compression_config = primary_config
+ else:
+ _assign_mixed_precision(internal_weight_params, ratio, primary_config)
diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py
index 6311ebdfe4c..2516b9e0913 100644
--- a/nncf/quantization/quantize_model.py
+++ b/nncf/quantization/quantize_model.py
@@ -12,6 +12,7 @@
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
from nncf.api.compression import TModel
+from nncf.common.deprecation import warning_deprecated
from nncf.common.factory import NNCFGraphFactory
from nncf.common.quantization.structs import QuantizationPreset
from nncf.common.utils.api_marker import api
@@ -241,7 +242,7 @@ def quantize_with_accuracy_control(
@api(canonical_alias="nncf.compress_weights")
def compress_weights(
model: TModel,
- mode=CompressWeightsMode.INT8,
+ mode=CompressWeightsMode.INT8_ASYM,
ratio: Optional[float] = None,
group_size: Optional[int] = None,
ignored_scope: Optional[IgnoredScope] = None,
@@ -251,17 +252,19 @@ def compress_weights(
:param model: A model to be compressed.
:param mode: Defines a mode for weight compression.
- INT8 stands for 8-bit integer quantization of all weights.
+ INT8_SYM stands for 8-bit integer symmetric quantization of all weights.
+        INT8_ASYM is the same as INT8_SYM mode, but weights are quantized asymmetrically
+            with a typical non-fixed zero point.
INT4_SYM stands for a mixed-precision weights quantization with 4-bit integer as a primary precision.
Weights are quantized to a primary precision symmetrically with a fixed zero point equals to 8.
- All embeddings and the last layer are always compressed to a backup precision, which is 8-bit integer,
+ All embeddings and the last layer are always compressed to a backup precision, which is INT8_ASYM,
by default. All others are quantized whether to 4-bit integer or to a backup precision depending on
criteria and the given ratio.
INT4_ASYM is the same as INT4_SYM mode, but weights are quantized to a primary precision asymmetrically
with a typical non-fixed zero point.
NF4 is the same as INT4_SYM mode, but primary precision is NF4 data type without zero point.
:param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4
- and the rest to INT8).
+ and the rest to INT8_ASYM).
:param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale).
The value -1 means no grouping.
:param ignored_scope: An ignored scope that defined the list of model control
@@ -269,6 +272,12 @@ def compress_weights(
:return: The non-trainable model with compressed weights.
"""
if mode == CompressWeightsMode.INT8:
+ warning_deprecated(
+            "`CompressWeightsMode.INT8` is deprecated. Please use `CompressWeightsMode.INT8_ASYM` instead."
+ )
+ mode = CompressWeightsMode.INT8_ASYM
+
+ if mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]:
if ratio is None:
ratio = 1
if group_size is None:
diff --git a/nncf/torch/quantization/quantize_model.py b/nncf/torch/quantization/quantize_model.py
index 9ae6496091d..91487604199 100644
--- a/nncf/torch/quantization/quantize_model.py
+++ b/nncf/torch/quantization/quantize_model.py
@@ -74,7 +74,7 @@ def quantize_impl(
def compress_weights_impl(
model: torch.nn.Module,
- mode=CompressWeightsMode.INT8,
+ mode=CompressWeightsMode.INT8_ASYM,
ratio: Optional[float] = None,
group_size: Optional[int] = None,
ignored_scope: Optional[IgnoredScope] = None,
@@ -85,17 +85,20 @@ def compress_weights_impl(
:param model: a Torch model for compression.
:param mode: Defines a mode for weight compression.
- INT8 stands for 8-bit integer quantization of all weights.
+ INT8_SYM stands for 8-bit integer symmetric quantization of all weights.
+        Weights are quantized symmetrically with a fixed zero point equal to 128.
+        INT8_ASYM is the same as INT8_SYM mode, but weights are quantized asymmetrically
+        with a typical non-fixed zero point.
INT4_SYM stands for a mixed-precision weights quantization with 4-bit integer as a primary precision.
Weights are quantized to a primary precision symmetrically with a fixed zero point equals to 8.
- All embeddings and the last layer are always compressed to a backup precision, which is 8-bit integer,
+ All embeddings and the last layer are always compressed to a backup precision, which is INT8_ASYM,
by default. All others are quantized whether to 4-bit integer or to a backup precision depending on
criteria and the given ratio.
INT4_ASYM is the same as INT4_SYM mode, but weights are quantized to a primary precision asymmetrically
with a typical non-fixed zero point.
NF4 is the same as INT4_SYM mode, but primary precision is NF4 data type without zero point.
:param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4
- and the rest to INT8).
+ and the rest to INT8_ASYM).
:param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale).
The value -1 means no grouping.
:param ignored_scope: An ignored scope that defined the list of model control
@@ -104,8 +107,10 @@ def compress_weights_impl(
"""
if ignored_scope is not None:
raise AttributeError("Torch backend does not support ignored scope.")
- if mode != CompressWeightsMode.INT8:
- raise AttributeError(f"Torch backend supports only INT8 mode for weight compression, but given {mode} mode.")
+ if mode != CompressWeightsMode.INT8_ASYM:
+ raise AttributeError(
+ f"Torch backend supports only INT8_ASYM mode for weight compression, but given {mode} mode."
+ )
compressed_model, _ = replace_modules_by_nncf_modules(model)
insert_pre_compression_operations(model)
diff --git a/tests/openvino/native/data/2023.2/reference_scales/IntegerModel_compressed_weights_int8.json b/tests/openvino/native/data/2023.2/reference_scales/IntegerModel_compressed_weights_int8_asym.json
similarity index 100%
rename from tests/openvino/native/data/2023.2/reference_scales/IntegerModel_compressed_weights_int8.json
rename to tests/openvino/native/data/2023.2/reference_scales/IntegerModel_compressed_weights_int8_asym.json
diff --git a/tests/openvino/native/data/2023.2/reference_scales/IntegerModel_compressed_weights_int8_sym.json b/tests/openvino/native/data/2023.2/reference_scales/IntegerModel_compressed_weights_int8_sym.json
new file mode 100644
index 00000000000..41b80d9aa5e
--- /dev/null
+++ b/tests/openvino/native/data/2023.2/reference_scales/IntegerModel_compressed_weights_int8_sym.json
@@ -0,0 +1,200 @@
+{
+ "matmul_2_data": {
+ "compressed_weight": [
+ [
+ 182,
+ 152,
+ 200,
+ 255,
+ 165,
+ 136,
+ 193
+ ],
+ [
+ 155,
+ 140,
+ 206,
+ 168,
+ 219,
+ 155,
+ 255
+ ],
+ [
+ 177,
+ 142,
+ 212,
+ 251,
+ 187,
+ 255,
+ 195
+ ],
+ [
+ 182,
+ 207,
+ 255,
+ 249,
+ 187,
+ 225,
+ 191
+ ],
+ [
+ 200,
+ 235,
+ 184,
+ 228,
+ 225,
+ 255,
+ 144
+ ],
+ [
+ 222,
+ 248,
+ 253,
+ 130,
+ 240,
+ 255,
+ 252
+ ]
+ ],
+ "zero_point": [
+ 128
+ ],
+ "scale": [
+ [
+ 0.006270269863307476
+ ],
+ [
+ 0.007418213412165642
+ ],
+ [
+ 0.007516460493206978
+ ],
+ [
+ 0.007835405878722668
+ ],
+ [
+ 0.007339052855968475
+ ],
+ [
+ 0.007725945208221674
+ ]
+ ]
+ },
+ "matmul_1_data": {
+ "compressed_weight": [
+ [
+ 185,
+ 208,
+ 133,
+ 152,
+ 255,
+ 251
+ ],
+ [
+ 206,
+ 177,
+ 255,
+ 253,
+ 215,
+ 211
+ ],
+ [
+ 249,
+ 196,
+ 152,
+ 255,
+ 220,
+ 183
+ ],
+ [
+ 194,
+ 249,
+ 255,
+ 177,
+ 206,
+ 172
+ ],
+ [
+ 213,
+ 176,
+ 184,
+ 255,
+ 160,
+ 217
+ ],
+ [
+ 140,
+ 249,
+ 242,
+ 163,
+ 255,
+ 136
+ ]
+ ],
+ "zero_point": [
+ 128
+ ],
+ "scale": [
+ [
+ 0.0052805072627961636
+ ],
+ [
+ 0.007852046750485897
+ ],
+ [
+ 0.005681010894477367
+ ],
+ [
+ 0.0073546734638512135
+ ],
+ [
+ 0.0070100342854857445
+ ],
+ [
+ 0.006901450455188751
+ ]
+ ]
+ },
+ "gather_2_data": {
+ "compressed_weight": [
+ [
+ 217,
+ 166,
+ 134,
+ 130,
+ 241,
+ 255
+ ],
+ [
+ 210,
+ 227,
+ 202,
+ 255,
+ 239,
+ 128
+ ],
+ [
+ 254,
+ 133,
+ 235,
+ 154,
+ 255,
+ 208
+ ]
+ ],
+ "zero_point": [
+ 128
+ ],
+ "scale": [
+ [
+ 0.007187051698565483
+ ],
+ [
+ 0.0073627750389277935
+ ],
+ [
+ 0.006796684116125107
+ ]
+ ]
+ }
+}
\ No newline at end of file
diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py
index ab154a3453a..a1e129d3b35 100644
--- a/tests/openvino/native/quantization/test_weights_compression.py
+++ b/tests/openvino/native/quantization/test_weights_compression.py
@@ -47,7 +47,7 @@ def get_next_node(node):
return next_node
-def check_int8_node(op: ov.Node):
+def check_int8_node(op: ov.Node, mode: CompressWeightsMode = CompressWeightsMode.INT8_ASYM):
assert op.get_element_type() == ov.Type(np.uint8)
compressed_weight = get_const_value(op)
@@ -62,6 +62,12 @@ def check_int8_node(op: ov.Node):
zero_point_node = convert_node.input_value(0).get_node()
zero_point = get_const_value(zero_point_node)
+ if mode == CompressWeightsMode.INT8_SYM:
+ assert list(zero_point_node.shape) == [1]
+ else:
+ reduced_weight_shape = list(op.shape)
+ reduced_weight_shape[-1] = 1
+ assert list(zero_point_node.shape) == reduced_weight_shape
mul_node = get_next_node(sub_node)
assert mul_node.get_type_name() == "Multiply"
@@ -144,6 +150,10 @@ def check_int4_asym_grouped(op: ov.Node):
return check_int4_grouped(op, mode=CompressWeightsMode.INT4_ASYM)
+def check_int8_sym(op: ov.Node):
+ return check_int8_node(op, mode=CompressWeightsMode.INT8_SYM)
+
+
def get_mixed_mapping(primary_fn: Callable, list_layers: List[str]):
mapping = {node_name: check_int8_node for node_name in list_layers}
primary_node_name = TEST_MODELS[IntegerModel][0]
@@ -154,7 +164,8 @@ def get_mixed_mapping(primary_fn: Callable, list_layers: List[str]):
@pytest.mark.parametrize(
("mode", "group_size", "check_fn_per_node_map"),
(
- (CompressWeightsMode.INT8, -1, {node_name: check_int8_node for node_name in TEST_MODELS[IntegerModel]}),
+ (CompressWeightsMode.INT8_ASYM, -1, {node_name: check_int8_node for node_name in TEST_MODELS[IntegerModel]}),
+ (CompressWeightsMode.INT8_SYM, -1, {node_name: check_int8_sym for node_name in TEST_MODELS[IntegerModel]}),
(CompressWeightsMode.INT4_SYM, 7, get_mixed_mapping(check_int4_sym_grouped, TEST_MODELS[IntegerModel])),
(CompressWeightsMode.INT4_ASYM, 7, get_mixed_mapping(check_int4_asym_grouped, TEST_MODELS[IntegerModel])),
(CompressWeightsMode.NF4, 7, get_mixed_mapping(check_nf4_grouped, TEST_MODELS[IntegerModel])),
@@ -197,9 +208,10 @@ def test_mixed_precision(ratio, group_size, ref_nf4_nodes):
assert op.get_element_type() == ov.Type.nf4
-def test_not_quantize_with_multiple_reduction_axes():
+@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM))
+def test_not_quantize_with_multiple_reduction_axes(mode):
model = GatherWithTwoReductionAxes().ov_model
- compressed_model = compress_weights(model, mode=CompressWeightsMode.INT8)
+ compressed_model = compress_weights(model, mode=mode)
for op in compressed_model.get_ordered_ops():
if op.get_type_name() == "Constant" and op.get_friendly_name() == "gather_1_data":
assert op.get_element_type() == ov.Type(np.float32)
@@ -408,11 +420,13 @@ def test_raise_error_with_tuple():
_reshape_weights_for_grouped_quantization(WEIGHTS_2x4, reduction_axis=(0,), group_size=3)
-def test_raise_error_with_int8_and_non_default_ratio(mocker):
+@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM))
+def test_raise_error_with_int8_and_non_default_ratio(mocker, mode):
with pytest.raises(AttributeError):
- compress_weights(mocker.Mock(), mode=CompressWeightsMode.INT8, ratio=0.5)
+ compress_weights(mocker.Mock(), mode=mode, ratio=0.5)
-def test_raise_error_with_int8_and_non_default_group_size(mocker):
+@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM))
+def test_raise_error_with_int8_and_non_default_group_size(mocker, mode):
with pytest.raises(AttributeError):
- compress_weights(mocker.Mock(), mode=CompressWeightsMode.INT8, group_size=64)
+ compress_weights(mocker.Mock(), mode=mode, group_size=64)
diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py
index 770664eaa9a..5a36c649ffd 100644
--- a/tests/torch/ptq/test_weights_compression.py
+++ b/tests/torch/ptq/test_weights_compression.py
@@ -74,18 +74,32 @@ def test_compress_shared_weights():
assert compressed_model.lm_head.get_pre_op(key) is val
-def test_raise_error_with_int8_and_non_default_ratio(mocker):
+@pytest.mark.parametrize(
+ "mode", [CompressWeightsMode.INT8, CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]
+)
+def test_raise_error_with_int8_and_non_default_ratio(mocker, mode):
with pytest.raises(AttributeError):
- compress_weights(mocker.Mock(), mode=CompressWeightsMode.INT8, ratio=0.5)
+ compress_weights(mocker.Mock(), mode=mode, ratio=0.5)
-def test_raise_error_with_int8_and_non_default_group_size(mocker):
+@pytest.mark.parametrize(
+ "mode", [CompressWeightsMode.INT8, CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]
+)
+def test_raise_error_with_int8_and_non_default_group_size(mocker, mode):
with pytest.raises(AttributeError):
- compress_weights(mocker.Mock(), mode=CompressWeightsMode.INT8, group_size=64)
-
-
-@pytest.mark.parametrize("mode", [CompressWeightsMode.NF4, CompressWeightsMode.INT4_ASYM, CompressWeightsMode.INT4_SYM])
-def test_raise_error_with_not_int8(mode):
+ compress_weights(mocker.Mock(), mode=mode, group_size=64)
+
+
+@pytest.mark.parametrize(
+ "mode",
+ [
+ CompressWeightsMode.NF4,
+ CompressWeightsMode.INT4_ASYM,
+ CompressWeightsMode.INT4_SYM,
+ CompressWeightsMode.INT8_SYM,
+ ],
+)
+def test_raise_error_with_not_int8_asym(mode):
with pytest.raises(AttributeError):
dummy_torch_model = torch.nn.Module()
compress_weights(dummy_torch_model, mode=mode)