Skip to content

Commit

Permalink
Scale for covering full range of compressed weights. (#2790)
Browse files Browse the repository at this point in the history
### Changes

Implemented a negative scale for symmetric compression to cover cases where
the positive weights span a wider range of values; this lets us use all 8
bits of the signed range in such cases instead of only 7.

### Reason for changes

Improvement in data-free accuracy.

### Related tickets

CVS-146491

### Tests

See changed files
  • Loading branch information
andreyanufr authored Jul 30, 2024
1 parent d94b93b commit d1e4ec5
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 144 deletions.
4 changes: 2 additions & 2 deletions nncf/quantization/advanced_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,9 @@ class AdvancedScaleEstimationParameters:
:type weight_penalty: float
"""

subset_size: int = 32
subset_size: int = 64
initial_steps: int = 5
scale_steps: int = 10
scale_steps: int = 5
weight_penalty: float = -1.0


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def apply(
"compress_model": compress_model,
}

scale_sign = scale / fns.abs(scale)
zero_scale = 0.001
zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)

Expand All @@ -229,6 +230,7 @@ def apply(
# iterative rectification of initial scale
for i in range(self._initial_steps):
near_to_ideal_scale = estimate_scales(original_weight, target, zero_mask, importance)
near_to_ideal_scale = near_to_ideal_scale * scale_sign
input_tensors[1] = near_to_ideal_scale.data

out = compress_decompress_model(input_tensors)
Expand Down Expand Up @@ -274,6 +276,7 @@ def apply(
target, zero_mask = get_target_zero_mask(compressed_weights, zp)
zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
near_to_ideal_scale = estimate_scales(original_weight, target, zero_mask, importance)
near_to_ideal_scale = near_to_ideal_scale * scale_sign

input_tensors[1] = near_to_ideal_scale.data
out = compress_decompress_model(input_tensors)
Expand Down
33 changes: 27 additions & 6 deletions nncf/quantization/algorithms/weight_compression/weight_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,31 @@ def calculate_e2m1_scale(weight: Tensor, reduction_axes: ReductionAxes, max_val=
return scale


def calculate_signed_scale(weight: Tensor, reduction_axes: ReductionAxes, num_bits=4) -> Tensor:
    """
    Calculates the signed scale for symmetric quantization.

    The scale is negated for groups whose positive range dominates, so the
    largest-magnitude weights map onto the wider (negative) half of the signed
    integer grid and the full range of quantization levels can be used.

    :param weight: Weight array to compress.
    :param reduction_axes: Axes along which to reduce (collect) different statistics (e.g., min, max).
    :param num_bits: number of bits in compression.
    :return: Scale tensor.
    """
    level_high = 2 ** (num_bits - 1)

    # Per-group absolute maximum; resulting shape example: [a1, r//gs, 1, a2].
    scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True)

    abs_of_min = fns.abs(fns.min(weight, axis=reduction_axes, keepdims=True))
    abs_of_max = fns.abs(fns.max(weight, axis=reduction_axes, keepdims=True))

    # Flip the denominator's sign where the positive side of the range is wider
    # than the negative one.
    denom = fns.ones_like(scale) * level_high
    denom = fns.where(abs_of_min < abs_of_max, -denom, denom)
    scale = scale / denom

    # Clamp tiny magnitudes to machine epsilon to avoid division by zero downstream.
    eps = fns.finfo(scale).eps
    scale = fns.where(fns.abs(scale) < eps, eps, scale)

    return scale


def calculate_normalized_weight(weight: Tensor, scale: Tensor) -> Tensor:
"""
Normalizes the weight tensor using the provided scale.
Expand Down Expand Up @@ -263,12 +288,7 @@ def calculate_integer_quantization_params(
)
return scale, zero_point

level_high = 2 ** (num_bits - 1) - 1
scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True) # [a1, r//gs, 1, a2]
scale /= level_high
eps = fns.finfo(scale).eps
# NOTE: adding machine epsilon to avoid division by zero
scale = fns.where(fns.abs(scale) < eps, eps, scale)
scale = calculate_signed_scale(weight, reduction_axes, num_bits)
return scale, None


Expand Down Expand Up @@ -309,6 +329,7 @@ def calculate_quantized_weight(
compressed_weights += zero_point.astype(weight.dtype)
compressed_weights = fns.round(compressed_weights)
compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(dtype)

return compressed_weights


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,32 @@
"scale": [
[
[
0.11376953125
-0.09954833984375
]
],
[
[
0.1346435546875
-0.11773681640625
]
],
[
[
0.1363525390625
-0.11932373046875
]
],
[
[
0.1422119140625
-0.1243896484375
]
],
[
[
0.1331787109375
-0.11651611328125
]
],
[
[
0.14013671875
-0.12261962890625
]
]
]
Expand Down
Loading

0 comments on commit d1e4ec5

Please sign in to comment.