diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index b7d3f263b52..156335e43b5 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -10,7 +10,7 @@ # limitations under the License. from copy import deepcopy -from typing import Any, Dict, List, Optional, TypeVar +from typing import Any, Dict, List, Optional, Tuple, TypeVar from nncf import Dataset from nncf.common.graph.graph import NNCFGraph @@ -301,7 +301,15 @@ def apply( return res -def get_target_zero_mask(compressed_weights, zp=None): +def get_target_zero_mask(compressed_weights: TTensor, zp: Optional[TTensor] = None) -> Tuple[TTensor, TTensor]: + """ + Computes the target values and a mask indicating zero values in the target. + + :param compressed_weights: The compressed weights tensor. + :param zp: The zero point tensor. + :return: The compressed weights optionally adjusted by the zero point and + a boolean mask indicating positions in the target that are close to zero. + """ target = compressed_weights if zp is not None: target = target.astype(dtype=zp.dtype) - zp @@ -309,7 +317,16 @@ def get_target_zero_mask(compressed_weights, zp=None): return target, zero_mask -def estimate_scales(weight, target, zero_mask, importance): +def estimate_scales(weight: TTensor, target: TTensor, zero_mask: TTensor, importance: TTensor) -> TTensor: + """ + Estimates scales for the given weight, target, zero mask, and importance. + + :param weight: The weights tensor. + :param target: The target values tensor. + :param zero_mask: A boolean mask indicating positions in the target that are close to zero. + :param importance: The importance values tensor. + :return: The estimated scales + """ ideal_scale = fns.abs(weight) / (fns.abs(target) + zero_mask) weighted_scale = ideal_scale * importance near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)