From 2552f1b4b87ba98bdc8a36aab9f8bb6ea0846afb Mon Sep 17 00:00:00 2001
From: andreyanufr
Date: Fri, 5 Jul 2024 09:11:52 +0200
Subject: [PATCH] 1) Fixed bug with clamp range in scale estimation. (#2781)

### Changes

Fixed a bug with the clamp ranges used by scale estimation: the OpenVINO compress pipeline now clamps symmetric INT4 weights to [-2**(num_bits - 1), 2**(num_bits - 1) - 1] instead of always using the asymmetric range [0, 2**num_bits - 1]. The pipeline builders now take a `WeightCompressionConfig` directly, and the NumPy reference gained an `invert_scale` option so it computes quantized weights the same way the OpenVINO graph does.

### Reason for changes

`get_compress_pipeline` applied the asymmetric clamp range in both modes, so in INT4_SYM mode every negative quantization level was clamped to 0.

### Related tickets

### Tests

Unit test comparing the NumPy compression/decompression reference against the compiled OpenVINO pipelines for INT4_SYM and INT4_ASYM.
---
 .../weight_compression/openvino_backend.py      | 17 +++++-----
 .../weight_compression/scale_estimation.py      |  6 ++--
 .../weight_compression/weight_lowering.py       | 18 +++++++++--
 .../quantization/test_weights_compression.py    | 31 +++++++++++++++++++
 4 files changed, 57 insertions(+), 15 deletions(-)

diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
index f520345aa93..78c4eaedfbd 100644
--- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -33,6 +33,7 @@
 from nncf.quantization.algorithms.weight_compression.awq_patterns import get_awq_patterns
 from nncf.quantization.algorithms.weight_compression.backend import AWQAlgoBackend
 from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
+from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
 from nncf.quantization.algorithms.weight_compression.weight_lowering import compress_weight
 from nncf.tensor import Tensor
@@ -227,11 +228,9 @@ def dump_parameters(
         dump_parameters(model, parameters, algo_name, path)

     @staticmethod
-    def get_compress_decompress_pipeline(
-        weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape=None
-    ):
+    def get_compress_decompress_pipeline(config: WeightCompressionConfig, w_shape, s_shape, z_p_shape=None):
         parameters, clamp = OVWeightCompressionAlgoBackend.get_compress_pipeline(
-            weight_compression_parameter, w_shape, s_shape, z_p_shape, True
+            config, w_shape, s_shape, z_p_shape, True
         )

         if len(parameters) == 3:
@@ -248,16 +247,14 @@ def get_compress_decompress_pipeline(
         return lambda parameters: compiled_model(parameters)[0]

     @staticmethod
-    def get_compress_pipeline(
-        weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape=None, return_nodes=False
-    ):
-        config = weight_compression_parameter.compression_config
+    def get_compress_pipeline(config: WeightCompressionConfig, w_shape, s_shape, z_p_shape=None, return_nodes=False):
         mode = config.mode
         assert mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]
         num_bits = config.num_bits

-        level_low = 0
-        level_high = 2**num_bits - 1
+        asym_quant = mode in [CompressWeightsMode.INT4_ASYM]
+        level_low = 0 if asym_quant else -(2 ** (num_bits - 1))
+        level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1

         w = opset.parameter(w_shape, name="w")
         s = opset.parameter(s_shape, name="s")
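Reviewer note: a minimal NumPy sketch of the clamp-range bug, not part of the patch; the weight values and the precomputed scale are illustrative only.

```python
import numpy as np

# Symmetric INT4 quantization of a tiny weight row (num_bits = 4).
w = np.array([-1.0, -0.5, 0.0, 0.5, 0.875], dtype=np.float32)
scale = np.float32(0.125)  # hypothetical precomputed scale

q = np.round(w / scale)    # quantization levels: [-8, -4, 0, 4, 7]
old = np.clip(q, 0, 15)    # old asymmetric clamp -> [0, 0, 0, 4, 7]
new = np.clip(q, -8, 7)    # fixed symmetric clamp -> [-8, -4, 0, 4, 7]

print(old * scale)  # [0.     0.     0.     0.5    0.875] -- negatives lost
print(new * scale)  # [-1.   -0.5    0.     0.5    0.875] -- exact round trip here
```

With the old range every negative level collapses to 0, so roughly half of a typical weight distribution dequantizes to zero in symmetric mode; the fixed range keeps all 16 levels usable.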
diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py
index 885adbb02ef..7ac5eb2aef1 100644
--- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py
+++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py
@@ -211,9 +211,11 @@ def apply(
                 compress_model = compress_decompress_cache[key]["compress_model"]
             else:
                 compress_decompress_model = self._backend_entity.get_compress_decompress_pipeline(
-                    wp, q_weights.shape, scale.shape, zp_shape
+                    wp.compression_config, q_weights.shape, scale.shape, zp_shape
+                )
+                compress_model = self._backend_entity.get_compress_pipeline(
+                    wp.compression_config, q_weights.shape, scale.shape, zp_shape
                 )
-                compress_model = self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp_shape)
                 compress_decompress_cache[key] = {
                     "compress_decompress_model": compress_decompress_model,
                     "compress_model": compress_model,
diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py
index ba11ebe2576..c94d014f868 100644
--- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py
+++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py
@@ -273,7 +273,11 @@ def calculate_integer_quantization_params(


 def calculate_quantized_weight(
-    weight: Tensor, config: WeightCompressionConfig, scale: Tensor, zero_point: Optional[Tensor] = None
+    weight: Tensor,
+    config: WeightCompressionConfig,
+    scale: Tensor,
+    zero_point: Optional[Tensor] = None,
+    invert_scale: bool = False,
 ) -> Tensor:
     """
     Quantizes the weight tensor using the provided scale and zero point.
@@ -282,6 +286,7 @@
     :param config: Weight compression configuration.
     :param scale: Scale tensor used for quantization.
     :param zero_point: Zero point tensor used for quantization.
+    :param invert_scale: If True, the scale is inverted and the weight is multiplied by it instead of divided by it.
     :return: Quantized weight tensor of uint8 or int8 type.
     """
     if weight.dtype != TensorDataType.float32:
@@ -295,7 +300,11 @@
     level_low = 0 if asym_quant else -(2 ** (num_bits - 1))
     level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1

-    compressed_weights = weight / scale
+    if invert_scale:
+        scale = fns.power(scale, -1)
+        compressed_weights = weight * scale
+    else:
+        compressed_weights = weight / scale
     if zero_point is not None:
         compressed_weights += zero_point.astype(weight.dtype)
     compressed_weights = fns.round(compressed_weights)
@@ -309,6 +318,7 @@ def do_integer_quantization(
     config: WeightCompressionConfig,
     precomputed_scale: Tensor = None,
     precomputed_zero_point: Tensor = None,
+    invert_scale: bool = False,
 ) -> Tuple[Tensor, Tensor, Tensor]:
     """
     The method quantizes the given weights to integer data type in accordance with the compression config.
@@ -331,6 +341,8 @@
     :param config: Information on how to compress (quantize) a specific weight.
     :param precomputed_scale: Precomputed scale.
     :param precomputed_zero_point: Precomputed zero point.
+    :param invert_scale: If True, the scale is inverted and the weight is multiplied by it instead of divided by it.
+        Needed as a reference implementation for OpenVINO.
     :return: The compressed weights tensor of uint8 (asymmetric mode) or int8 (symmetric mode) type,
         scale tensor of float32 type and zero point tensor of int32 type that was used for its quantization.
     """
@@ -351,7 +363,7 @@ def do_integer_quantization(
         if precomputed_zero_point is not None:
             zero_point = precomputed_zero_point

-    compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point)
+    compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point, invert_scale)

     return compressed_weights, scale, zero_point
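Reviewer note: `invert_scale` exists because `w / s` and `w * (1 / s)` are not guaranteed to round identically in float32, and a one-ULP difference before `fns.round` can flip a quantization level near a half-integer boundary. A self-contained sketch of the effect; the random data and the scale value are assumptions, not from the patch:

```python
import numpy as np

# Compare plain division against multiplication by the inverted scale,
# which is what invert_scale=True does in the NumPy reference.
rng = np.random.default_rng(0)
w = rng.standard_normal((4, 4096)).astype(np.float32)
scale = np.float32(0.0123)  # arbitrary illustrative scale

q_div = np.round(w / scale)
q_inv = np.round(w * np.power(scale, np.float32(-1)))

# Usually zero or a handful of elements differ; any nonzero count means a
# bit-exact comparison against the OV pipeline needs the same order of ops.
print(np.count_nonzero(q_div != q_inv))
```

This is why the new unit test quantizes with `invert_scale=True` before asserting closeness against the compiled OpenVINO pipelines.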
""" @@ -351,7 +363,7 @@ def do_integer_quantization( if precomputed_zero_point is not None: zero_point = precomputed_zero_point - compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point) + compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point, invert_scale) return compressed_weights, scale, zero_point diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index 0b01674e43c..c9547d38493 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -28,6 +28,9 @@ from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.quantization.algorithms.weight_compression.mixed_precision import MIXED_PRECISION_CRITERIA +from nncf.quantization.algorithms.weight_compression.openvino_backend import OVWeightCompressionAlgoBackend +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_dequantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization from nncf.quantization.algorithms.weight_compression.weight_lowering import get_integer_quantization_error from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization from nncf.scopes import IgnoredScope @@ -912,3 +915,31 @@ def test_mixed_precision_e2m1(mode, all_layers, ratio, ref_ids): } ref_e8m0_nodes = {f"weights_{i}/scale" for i in ref_ids} assert ref_e8m0_nodes == names_e8m0 + + +@pytest.mark.parametrize("mode", (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM)) +def test_np_ov_compression_decompression(mode): + sz = 60 + w = np.arange(-sz, sz).reshape(2, sz).astype(np.float32) / 9.0 + w = Tensor(w) + + config = WeightCompressionConfig(mode) + + compressed_weighs, scale, zp = do_integer_quantization(w, -1, config, invert_scale=True) + decompressed_weighs = do_dequantization(compressed_weighs, scale, zp) + + compressed_weighs = compressed_weighs.data + decompressed_weighs = decompressed_weighs.data + zp_shape = zp.shape if zp is not None else None + + compress = OVWeightCompressionAlgoBackend.get_compress_pipeline(config, w.shape, scale.shape, zp_shape) + compress_decompress = OVWeightCompressionAlgoBackend.get_compress_decompress_pipeline( + config, w.shape, scale.shape, zp_shape + ) + + params = [w.data, scale.data, zp.data] if zp is not None else [w.data, scale.data] + compressed_weighs_ov = compress(params) + decompressed_weighs_ov = compress_decompress(params) + + assert np.allclose(compressed_weighs, compressed_weighs_ov) + assert np.allclose(decompressed_weighs, decompressed_weighs_ov)