Skip to content

Commit

Permalink
docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
l-bat committed Jun 13, 2024
1 parent f81cc85 commit 3e6c649
Showing 1 changed file with 20 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

from copy import deepcopy
from typing import Any, Dict, List, Optional, TypeVar
from typing import Any, Dict, List, Optional, Tuple, TypeVar

from nncf import Dataset
from nncf.common.graph.graph import NNCFGraph
Expand Down Expand Up @@ -301,15 +301,32 @@ def apply(
return res


def get_target_zero_mask(compressed_weights, zp=None):
def get_target_zero_mask(compressed_weights: TTensor, zp: Optional[TTensor] = None) -> Tuple[TTensor, TTensor]:
"""
Computes the target values and a mask indicating zero values in the target.
:param compressed_weights: The compressed weights tensor.
:param zp: The zero point tensor.
:return: The compressed weights optionally adjusted by the zero point and
a boolean mask indicating positions in the target that are close to zero.
"""
target = compressed_weights
if zp is not None:
target = target.astype(dtype=zp.dtype) - zp
zero_mask = fns.isclose(target, 0)
return target, zero_mask


def estimate_scales(weight, target, zero_mask, importance):
def estimate_scales(weight: TTensor, target: TTensor, zero_mask: TTensor, importance: TTensor) -> TTensor:
"""
Estimates scales for the given weight, target, zero mask, and importance.
:param weight: The weights tensor.
:param target: The target values tensor.
:param zero_mask: A boolean mask indicating positions in the target that are close to zero.
:param importance: The importance values tensor.
:return: The estimated scales
"""
ideal_scale = fns.abs(weight) / (fns.abs(target) + zero_mask)
weighted_scale = ideal_scale * importance
near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)
Expand Down

0 comments on commit 3e6c649

Please sign in to comment.