Skip to content

Commit

Permalink
replied to comments
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsu52 committed Oct 24, 2024
1 parent 725600e commit 9a1432c
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions nncf/torch/quantization/quantize_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch

from nncf.common.logging import nncf_logger
from nncf.errors import ValidationError
from nncf.torch.dynamic_graph.patch_pytorch import register_operator
from nncf.torch.functions import STRound
from nncf.torch.functions import clamp
Expand Down Expand Up @@ -286,10 +287,10 @@ def pack_uint4(tensor: torch.Tensor) -> torch.Tensor:
:param tensor: A tensor of dtype `torch.uint8` where each element represents a uint4 value.
The tensor should contain values in the range [0, 15].
:return: A packed tensor of dtype `torch.uint8` where each element packs two uint4 values.
:raises ValueError: If the input tensor is not of type `torch.uint8`.
:raises nncf.errors.ValidationError: If the input tensor is not of type `torch.uint8`.
"""
if tensor.dtype != torch.uint8:
raise ValueError(f"Invalid tensor dtype {tensor.type}. torch.uint8 type is supported.")
raise ValidationError(f"Invalid tensor dtype {tensor.type}. torch.uint8 type is supported.")
packed_tensor = tensor.contiguous()
packed_tensor = packed_tensor.reshape(-1, 2)
packed_tensor = torch.bitwise_and(packed_tensor[..., ::2], 15) | packed_tensor[..., 1::2] << 4
Expand All @@ -299,7 +300,7 @@ def pack_uint4(tensor: torch.Tensor) -> torch.Tensor:
@register_operator()
def unpack_uint4(packed_tensor: torch.Tensor) -> torch.Tensor:
"""
Unpacks a tensor where each uint8 element stores two uint4 values back into a tensor with
Unpacks a tensor, where each uint8 element stores two uint4 values, back into a tensor with
individual uint4 values.
:param packed_tensor: A tensor of dtype `torch.uint8` where each element packs two uint4 values.
Expand All @@ -316,18 +317,18 @@ def pack_int4(tensor: torch.Tensor) -> torch.Tensor:
:param tensor: A tensor of dtype `torch.int8` where each element represents an int4 value.
The tensor should contain values in the range [-8, 7].
:return: A packed tensor of dtype `torch.uint8` where each element packs two int4 values.
:raises ValueError: If the input tensor is not of type `torch.int8`.
:raises nncf.errors.ValidationError: If the input tensor is not of type `torch.int8`.
"""
if tensor.dtype != torch.int8:
raise ValueError(f"Invalid tensor dtype {tensor.type}. torch.int8 type is supported.")
raise ValidationError(f"Invalid tensor dtype {tensor.type}. torch.int8 type is supported.")
tensor = tensor + 8
return pack_uint4(tensor.type(torch.uint8))


@register_operator()
def unpack_int4(packed_tensor: torch.Tensor) -> torch.Tensor:
"""
Unpacks a tensor where each uint8 element stores two int4 values back into a tensor with
Unpacks a tensor, where each uint8 element stores two int4 values, back into a tensor with
individual int4 values.
:param packed_tensor: A tensor of dtype `torch.uint8` where each element packs two int4 values.
Expand Down

0 comments on commit 9a1432c

Please sign in to comment.