# 1) Fixed bug with clamp range in scale estimation. (#2781)
### Changes

Fixed bug with clamp ranges in scale estimation.

### Reason for changes

For symmetric INT4 mode, `get_compress_pipeline` clamped the quantized weights to the asymmetric range `[0, 2**num_bits - 1]` instead of the symmetric range `[-2**(num_bits - 1), 2**(num_bits - 1) - 1]`, so negative quantization levels were lost.

### Related tickets



### Tests

Unit test (`test_np_ov_compression_decompression`) comparing the NumPy compression/decompression path against the OpenVINO pipelines for INT4_SYM and INT4_ASYM.
andreyanufr authored Jul 5, 2024
1 parent d256991 commit 2552f1b
Showing 4 changed files with 57 additions and 15 deletions.
nncf/quantization/algorithms/weight_compression/openvino_backend.py

```diff
@@ -33,6 +33,7 @@
 from nncf.quantization.algorithms.weight_compression.awq_patterns import get_awq_patterns
 from nncf.quantization.algorithms.weight_compression.backend import AWQAlgoBackend
 from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
+from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
 from nncf.quantization.algorithms.weight_compression.weight_lowering import compress_weight
 from nncf.tensor import Tensor
@@ -227,11 +228,9 @@ def dump_parameters(
         dump_parameters(model, parameters, algo_name, path)
 
     @staticmethod
-    def get_compress_decompress_pipeline(
-        weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape=None
-    ):
+    def get_compress_decompress_pipeline(config: WeightCompressionConfig, w_shape, s_shape, z_p_shape=None):
         parameters, clamp = OVWeightCompressionAlgoBackend.get_compress_pipeline(
-            weight_compression_parameter, w_shape, s_shape, z_p_shape, True
+            config, w_shape, s_shape, z_p_shape, True
         )
 
         if len(parameters) == 3:
@@ -248,16 +247,14 @@ def get_compress_decompress_pipeline(
         return lambda parameters: compiled_model(parameters)[0]
 
     @staticmethod
-    def get_compress_pipeline(
-        weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape=None, return_nodes=False
-    ):
-        config = weight_compression_parameter.compression_config
+    def get_compress_pipeline(config: WeightCompressionConfig, w_shape, s_shape, z_p_shape=None, return_nodes=False):
         mode = config.mode
         assert mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]
         num_bits = config.num_bits
 
-        level_low = 0
-        level_high = 2**num_bits - 1
+        asym_quant = mode in [CompressWeightsMode.INT4_ASYM]
+        level_low = 0 if asym_quant else -(2 ** (num_bits - 1))
+        level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1
 
         w = opset.parameter(w_shape, name="w")
         s = opset.parameter(s_shape, name="s")
```
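This clamp-range change is the actual bug fix. A quick NumPy illustration (my own sketch, not part of the commit): for `num_bits = 4` a symmetric grid spans `[-8, 7]`, so the old unconditional `[0, 15]` clamp silently collapsed every negative quantization level to zero.

```python
import numpy as np

num_bits = 4
weight = np.array([-1.0, -0.4, 0.0, 0.4, 1.0], dtype=np.float32)
# Max-abs scale mapping the largest magnitude onto the symmetric INT4 grid [-8, 7].
scale = np.float32(np.abs(weight).max() / (2 ** (num_bits - 1) - 1))

q = np.round(weight / scale)

old = np.clip(q, 0, 2**num_bits - 1)  # buggy: asymmetric range used for both modes
new = np.clip(q, -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1)  # fixed symmetric range

print(old)  # [0. 0. 0. 3. 7.] -- negative levels collapsed to zero
print(new)  # [-7. -3.  0.  3.  7.]
```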
nncf/quantization/algorithms/weight_compression/scale_estimation.py

```diff
@@ -211,9 +211,11 @@ def apply(
             compress_model = compress_decompress_cache[key]["compress_model"]
         else:
             compress_decompress_model = self._backend_entity.get_compress_decompress_pipeline(
-                wp, q_weights.shape, scale.shape, zp_shape
+                wp.compression_config, q_weights.shape, scale.shape, zp_shape
             )
-            compress_model = self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp_shape)
+            compress_model = self._backend_entity.get_compress_pipeline(
+                wp.compression_config, q_weights.shape, scale.shape, zp_shape
+            )
             compress_decompress_cache[key] = {
                 "compress_decompress_model": compress_decompress_model,
                 "compress_model": compress_model,
```
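Note that both pipeline builders now take the `WeightCompressionConfig` directly rather than the full `WeightCompressionParameters`; besides simplifying these call sites, that is what lets the new unit test at the bottom of the commit drive the OV pipelines from a bare config.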
18 changes: 15 additions & 3 deletions nncf/quantization/algorithms/weight_compression/weight_lowering.py

```diff
@@ -273,7 +273,11 @@ def calculate_integer_quantization_params(
 
 
 def calculate_quantized_weight(
-    weight: Tensor, config: WeightCompressionConfig, scale: Tensor, zero_point: Optional[Tensor] = None
+    weight: Tensor,
+    config: WeightCompressionConfig,
+    scale: Tensor,
+    zero_point: Optional[Tensor] = None,
+    invert_scale=False,
 ) -> Tensor:
     """
     Quantizes the weight tensor using the provided scale and zero point.
@@ -282,6 +286,7 @@
     :param config: Weight compression configuration.
     :param scale: Scale tensor used for quantization.
     :param zero_point: Zero point tensor used for quantization.
+    :param invert_scale: If True, invert the scale and multiply the weight by it instead of dividing by the scale.
     :return: Quantized weight tensor of uint8 or int8 type.
     """
     if weight.dtype != TensorDataType.float32:
@@ -295,7 +300,11 @@
     level_low = 0 if asym_quant else -(2 ** (num_bits - 1))
     level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1
 
-    compressed_weights = weight / scale
+    if invert_scale:
+        scale = fns.power(scale, -1)
+        compressed_weights = weight * scale
+    else:
+        compressed_weights = weight / scale
     if zero_point is not None:
         compressed_weights += zero_point.astype(weight.dtype)
     compressed_weights = fns.round(compressed_weights)
```
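The new `invert_scale` branch lets the NumPy path mirror what the OV graph evaluates, which, judging from the docstrings in this file, is effectively a multiply by the reciprocal scale rather than a true division. In float32 the two are not always bit-identical; a standalone check (mine, not from the commit) shows how often the raw quotients disagree:

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.uniform(-4.0, 4.0, 100_000).astype(np.float32)
s = rng.uniform(0.05, 2.0, 100_000).astype(np.float32)

divided = w / s                         # plain division
multiplied = w * (np.float32(1.0) / s)  # multiply by the inverted scale

mismatches = np.count_nonzero(divided != multiplied)
print(f"{mismatches} of {w.size} quotients differ in the last ulp")
# np.round() hides most of these, but a quotient that straddles a .5 boundary
# in one variant and not the other ends up a full quantization level apart.
```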
```diff
@@ -309,6 +318,7 @@ def do_integer_quantization(
     config: WeightCompressionConfig,
     precomputed_scale: Tensor = None,
     precomputed_zero_point: Tensor = None,
+    invert_scale=False,
 ) -> Tuple[Tensor, Tensor, Tensor]:
     """
     The method quantizes the given weights to integer data type in accordance with the compression config.
@@ -331,6 +341,8 @@
     :param config: Information on how to compress (quantize) a specific weight.
     :param precomputed_scale: Precomputed scale.
     :param precomputed_zero_point: Precomputed zero point.
+    :param invert_scale: If True, invert the scale and multiply the weight by it instead of dividing.
+        Needed as a reference implementation for OV.
     :return: The compressed weights tensor of uint8 (asymmetric mode) or int8 (symmetric mode) type,
         scale tensor of float32 type and zero point tensor of int32 type that was used for its quantization.
     """
@@ -351,7 +363,7 @@
     if precomputed_zero_point is not None:
         zero_point = precomputed_zero_point
 
-    compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point)
+    compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point, invert_scale)
     return compressed_weights, scale, zero_point
```
31 changes: 31 additions & 0 deletions tests/openvino/native/quantization/test_weights_compression.py

```diff
@@ -28,6 +28,9 @@
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
 from nncf.quantization.algorithms.weight_compression.mixed_precision import MIXED_PRECISION_CRITERIA
+from nncf.quantization.algorithms.weight_compression.openvino_backend import OVWeightCompressionAlgoBackend
+from nncf.quantization.algorithms.weight_compression.weight_lowering import do_dequantization
+from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization
 from nncf.quantization.algorithms.weight_compression.weight_lowering import get_integer_quantization_error
 from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization
 from nncf.scopes import IgnoredScope
@@ -912,3 +915,31 @@ def test_mixed_precision_e2m1(mode, all_layers, ratio, ref_ids):
     }
     ref_e8m0_nodes = {f"weights_{i}/scale" for i in ref_ids}
     assert ref_e8m0_nodes == names_e8m0
+
+
+@pytest.mark.parametrize("mode", (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM))
+def test_np_ov_compression_decompression(mode):
+    sz = 60
+    w = np.arange(-sz, sz).reshape(2, sz).astype(np.float32) / 9.0
+    w = Tensor(w)
+
+    config = WeightCompressionConfig(mode)
+
+    compressed_weights, scale, zp = do_integer_quantization(w, -1, config, invert_scale=True)
+    decompressed_weights = do_dequantization(compressed_weights, scale, zp)
+
+    compressed_weights = compressed_weights.data
+    decompressed_weights = decompressed_weights.data
+    zp_shape = zp.shape if zp is not None else None
+
+    compress = OVWeightCompressionAlgoBackend.get_compress_pipeline(config, w.shape, scale.shape, zp_shape)
+    compress_decompress = OVWeightCompressionAlgoBackend.get_compress_decompress_pipeline(
+        config, w.shape, scale.shape, zp_shape
+    )
+
+    params = [w.data, scale.data, zp.data] if zp is not None else [w.data, scale.data]
+    compressed_weights_ov = compress(params)
+    decompressed_weights_ov = compress_decompress(params)
+
+    assert np.allclose(compressed_weights, compressed_weights_ov)
+    assert np.allclose(decompressed_weights, decompressed_weights_ov)
```
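With an OpenVINO install available, the new test can be run on its own via `pytest tests/openvino/native/quantization/test_weights_compression.py -k test_np_ov_compression_decompression`; it exercises both INT4 modes and asserts element-wise agreement (`np.allclose`) between the NumPy reference computed with `invert_scale=True` and the compiled OV pipelines.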
