Skip to content

Commit

Permalink
Represent symmetrically quantized weights in signed data type
Browse files Browse the repository at this point in the history
  • Loading branch information
l-bat committed Jan 29, 2024
1 parent 4c01098 commit e892fec
Show file tree
Hide file tree
Showing 8 changed files with 226 additions and 203 deletions.
4 changes: 2 additions & 2 deletions docs/compression_algorithms/CompressWeights.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ The Weights Compression algorithm is aimed at compressing the weights of the mod
#### Supported modes

By default, weights are compressed asymmetrically to 8-bit integer data type - "INT8_ASYM" mode.
OpenVINO backend also supports 3 modes of mixed precision weight quantization with a 4-bit data type as a primary precision - INT4_SYM, INT4_ASYM and NF4. The primary precision in case of INT4_SYM mode is unsigned 4-bit integer and weights are quantized to it [symmetrically](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#symmetric-quantization) with a fixed zero point equals to 8. In case of INT4_ASYM mode - also unsigned 4-bit integer, but weight are quantized to it [asymmetrically](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#asymmetric-quantization) with a typical non-fixed zero point. In case of NF4 mode - [nf4](https://arxiv.org/pdf/2305.14314v1.pdf) data type without zero point.
OpenVINO backend also supports 3 modes of mixed precision weight quantization with a 4-bit data type as a primary precision - INT4_SYM, INT4_ASYM and NF4. The primary precision in case of INT4_SYM mode is signed 4-bit integer and weights are quantized to it [symmetrically](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#symmetric-quantization) without zero point. In case of INT4_ASYM mode - unsigned 4-bit integer, but weight are quantized to it [asymmetrically](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#asymmetric-quantization) with a typical non-fixed zero point. In case of NF4 mode - [nf4](https://arxiv.org/pdf/2305.14314v1.pdf) data type without zero point.
All 4-bit modes have a grouped quantization support, when small group of weights (e.g. 128) in the channel dimension share quantization parameters (scale).
All embeddings and last linear layers are always compressed to 8-bit integer data type.
Percent of the rest layers compressed to 4-bit can be configured by "ratio" parameter. E.g. ratio=0.9 means 90% of layers compressed to the corresponding 4-bit data type and the rest to 8-bit asymmetric integer data type.
Expand Down Expand Up @@ -285,7 +285,7 @@ Here is the word perplexity with data-free and data-aware mixed-precision INT4-I
- The algorithm is supported for OpenVINO and PyTorch models.
- The compression applies in-place.
- The compressed model is not trainable.
- INT8_SYM, INT4_SYM, INT4_ASYM and NF4 modes, grouped quantization and mixed precision selection is available for OpenVINO backend only.
- INT4_SYM, INT4_ASYM and NF4 modes, grouped quantization and mixed precision selection is available for OpenVINO backend only.
- NF4 support is experimental - models quantized to nf4 should not be faster models quantized to 8-bit integer.

#### Additional resources
Expand Down
38 changes: 20 additions & 18 deletions nncf/quantization/algorithms/weight_compression/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.transformations.commands import TargetType
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.tensor.definitions import TensorDataType
from nncf.experimental.tensor.functions import count_nonzero
from nncf.experimental.tensor.tensor import Tensor
from nncf.openvino.graph.metatypes.openvino_metatypes import OVEmbeddingMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
Expand Down Expand Up @@ -96,17 +98,14 @@ def transform_model(
compression_config = wc_params.compression_config
if compression_config.mode == CompressWeightsMode.NF4:
compression_dtype = ov.Type.nf4
elif compression_config.mode in [
CompressWeightsMode.INT8_ASYM,
CompressWeightsMode.INT8_SYM,
CompressWeightsMode.INT8,
CompressWeightsMode.INT4_ASYM,
CompressWeightsMode.INT4_SYM,
]:
if compression_config.mode in [CompressWeightsMode.INT4_ASYM, CompressWeightsMode.INT4_SYM]:
compression_dtype = ov.Type.u4
else:
compression_dtype = ov.Type.u8
elif compression_config.mode == CompressWeightsMode.INT4_SYM:
compression_dtype = ov.Type.i4
elif compression_config.mode == CompressWeightsMode.INT4_ASYM:
compression_dtype = ov.Type.u4
elif compression_config.mode == CompressWeightsMode.INT8_SYM:
compression_dtype = ov.Type.i8
elif compression_config.mode == CompressWeightsMode.INT8_ASYM:
compression_dtype = ov.Type.u8
else:
raise ValueError(f"{compression_config.mode.value} is not supported.")

Expand All @@ -124,13 +123,16 @@ def transform_model(
)
converted_const = opset.convert(compressed_const, const_dtype)
if compressed_weight.zero_point is not None:
zero_point_const = opset.constant(
compressed_weight.zero_point.data,
dtype=compression_dtype,
name=f"{const_node_name}/zero_point",
)
converted_zero_point = opset.convert(zero_point_const, const_dtype)
converted_const = opset.subtract(converted_const, converted_zero_point)
if compressed_weight.tensor.dtype == TensorDataType.int8:
assert count_nonzero(compressed_weight.zero_point.data) == 0
else:
zero_point_const = opset.constant(
compressed_weight.zero_point.data,
dtype=compression_dtype,
name=f"{const_node_name}/zero_point",
)
converted_zero_point = opset.convert(zero_point_const, const_dtype)
converted_const = opset.subtract(converted_const, converted_zero_point)

scale_data = compressed_weight.scale.data
mul = opset.multiply(
Expand Down
20 changes: 14 additions & 6 deletions nncf/quantization/algorithms/weight_compression/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.tensor.definitions import TensorDataType
from nncf.experimental.tensor.functions import count_nonzero
from nncf.experimental.tensor.tensor import Tensor
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
Expand All @@ -34,7 +35,8 @@
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.layers import WeightsDecompressor
from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor
from nncf.torch.quantization.layers import SymmetricWeightsDecompressor
from nncf.torch.tensor_statistics.collectors import get_raw_stat_collector


Expand Down Expand Up @@ -210,7 +212,11 @@ def transform_model(
compressed_weight = compress_weight(Tensor(weight), wc_params.reduction_axis, compression_config)

# pack compressed tensor
packed_tensor = compressed_weight.tensor.astype(TensorDataType.uint8)
if compression_config.mode == CompressWeightsMode.INT8_SYM:
dtype = TensorDataType.int8
else:
dtype = TensorDataType.uint8
packed_tensor = compressed_weight.tensor.astype(dtype)

# sets compressed tensor
compressed_parameter = torch.nn.Parameter(packed_tensor.data, requires_grad=False)
Expand All @@ -224,11 +230,13 @@ def transform_model(
if id(param) == id(weight):
setattr(c_module, name, compressed_parameter)

# pack zero point tensor
packed_zero_point = compressed_weight.zero_point.astype(TensorDataType.uint8)

# creates weight decompressor
decompressor = WeightsDecompressor(compressed_weight.scale.data, packed_zero_point.data)
if compression_config.mode == CompressWeightsMode.INT8_SYM:
assert count_nonzero(compressed_weight.zero_point) == 0
decompressor = SymmetricWeightsDecompressor(compressed_weight.scale.data)
else:
packed_zero_point = compressed_weight.zero_point.astype(dtype)
decompressor = AsymmetricWeightsDecompressor(compressed_weight.scale.data, packed_zero_point.data)

# registry weight decompression module in the model
decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"
Expand Down
31 changes: 16 additions & 15 deletions nncf/quantization/algorithms/weight_compression/weight_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@ def do_integer_quantization(
"""
The method quantizes the given weights to integer data type in accordance with the compression config.
The config defines a quantization mode:
INT8_SYM mode refers to unsigned int8 symmetric weight compression with a fixed zero point equals to 128 -
quantization to [0, 255] range.
INT8_SYM mode refers to signed int8 symmetric weight compression without zero point -
quantization to [-128, 127] range.
INT8_ASYM mode refers to unsigned int8 asymmetric weight compression with a typical non-fixed zero-point -
quantization to [0, 255] range.
INT4_ASYM mode refers to unsigned int4 asymmetric weight compression with a typical non-fixed zero-point -
quantization to [0, 15] range.
INT4_SYM mode refers to unsigned int4 symmetric weight compression with a fixed zero point equals to 8 -
quantization to [0, 15] range.
INT4_SYM mode refers to signed int4 symmetric weight compression without zero point -
quantization to [-8, 7] range.
NF4 mode requires a dedicated procedure and it is not supported in this method.
One of the parameter of compression config is a group size. Quantization is per-channel, if group size equals to -1,
otherwise it's per-group, i.e. group size number of weights in the channel dimension share quantization parameters
Expand All @@ -113,17 +113,14 @@ def do_integer_quantization(
:param weight: Weight array to compress.
:param reduction_axis: Axis, along which to reduce (collect) different statistics (e.g. min, max).
:param config: Information on how to compress (quantize) a specific weight.
:return: The compressed weights tensor of uint8 type, scale tensor of float32 type and
zero point tensor of int32 type that was used for its quantization.
:return: The compressed weights tensor of uint8 (asymmetric mode) or int8 (symmetric mode) type,
scale tensor of float32 type and zero point tensor of int32 type that was used for its quantization.
"""
mode = config.mode
assert mode != CompressWeightsMode.NF4, "The function supports integer quantization only"
group_size = config.group_size
num_bits = config.num_bits

level_low = 0
level_high = 2**num_bits - 1

if weight.dtype != TensorDataType.float32:
weight = weight.astype(TensorDataType.float32)

Expand All @@ -132,23 +129,27 @@ def do_integer_quantization(
weight, reduction_axis = reshape_weight_for_grouped_quantization(weight, reduction_axis, group_size)

if mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM]:
level_low = 0
level_high = 2**num_bits - 1
min_values = fns.min(weight, axis=reduction_axis, keepdims=True) # [a1, r, a2] -> [a1, 1, a2]
max_values = fns.max(weight, axis=reduction_axis, keepdims=True) # [a1, r, a2] -> [a1, 1, a2]
scale, zero_point = calculate_scale_zero_point(
min_values, max_values, level_low, level_high, narrow_range=False
)
compressed_weights = fns.round(weight / scale + zero_point.astype(weight.dtype))
compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(TensorDataType.uint8)
else:
level_low = -(2 ** (num_bits - 1))
level_high = 2 ** (num_bits - 1) - 1
scale = fns.max(fns.abs(weight), axis=reduction_axis, keepdims=True) # [a1, r//gs, 1, a2]
level_low_sym = -(2 ** (num_bits - 1))
level_high_sym = 2 ** (num_bits - 1) - 1
scale = scale / level_high_sym
zero_point = fns.as_tensor_like(scale, [-level_low_sym])
scale = scale / level_high
zero_point = fns.zeros_like(scale)
eps = fns.finfo(scale).eps
# NOTE: adding machine epsilon to avoid division by zero
scale = fns.where(fns.abs(scale) < eps, eps, scale)
compressed_weights = fns.round(weight / scale)
compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(TensorDataType.int8)

compressed_weights = fns.round(weight / scale + zero_point.astype(weight.dtype))
compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(TensorDataType.uint8)
return compressed_weights, scale, zero_point


Expand Down
21 changes: 19 additions & 2 deletions nncf/torch/quantization/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,9 +1034,9 @@ def get_scale_shape(input_shape: List[int], is_weights: bool, per_channel: bool,
return get_per_channel_scale_shape(input_shape, is_weights, channel_idx)


class WeightsDecompressor(nn.Module):
class AsymmetricWeightsDecompressor(nn.Module):
"""
Applies decompression of compressed weights in the forward pass
Applies asymmetric decompression of compressed weights in the forward pass
"""

def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor):
Expand All @@ -1050,3 +1050,20 @@ def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor):

def forward(self, x):
return decompress(x, self._scale, self._zero_point)


class SymmetricWeightsDecompressor(nn.Module):
"""
Applies symmetric decompression of compressed weights in the forward pass
"""

def __init__(self, scale: torch.Tensor):
"""
:param scale: A scale in quantization scheme
"""
super().__init__()
self.register_buffer("_scale", scale)

def forward(self, x):
zero_point = torch.zeros_like(self._scale)
return decompress(x, self._scale, zero_point)
Loading

0 comments on commit e892fec

Please sign in to comment.