From 9a1432cf363a44ed717ba67765fa61e8c4400071 Mon Sep 17 00:00:00 2001
From: Aleksandr Suslov
Date: Thu, 24 Oct 2024 16:19:29 +0400
Subject: [PATCH] replied to comments

---
Notes (ignored when the patch is applied):
- The added error messages interpolate `tensor.dtype`. `tensor.type` is a
  method, so formatting it prints the bound-method repr, not the dtype.
  The removed lines are left untouched so they still match the pre-image.
- Line structure was restored from a whitespace-collapsed copy. Leading
  indentation of context lines is reconstructed; if a hunk is rejected,
  compare the context against the target file's whitespace.

 nncf/torch/quantization/quantize_functions.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/nncf/torch/quantization/quantize_functions.py b/nncf/torch/quantization/quantize_functions.py
index 239b1e97371..b967bb57683 100644
--- a/nncf/torch/quantization/quantize_functions.py
+++ b/nncf/torch/quantization/quantize_functions.py
@@ -13,6 +13,7 @@
 import torch
 
 from nncf.common.logging import nncf_logger
+from nncf.errors import ValidationError
 from nncf.torch.dynamic_graph.patch_pytorch import register_operator
 from nncf.torch.functions import STRound
 from nncf.torch.functions import clamp
@@ -286,10 +287,10 @@ def pack_uint4(tensor: torch.Tensor) -> torch.Tensor:
     :param tensor: A tensor of dtype `torch.uint8` where each element represents a uint4 value.
         The tensor should contain values in the range [0, 15].
     :return: A packed tensor of dtype `torch.uint8` where each element packs two uint4 values.
-    :raises ValueError: If the input tensor is not of type `torch.uint8`.
+    :raises nncf.errors.ValidationError: If the input tensor is not of type `torch.uint8`.
     """
     if tensor.dtype != torch.uint8:
-        raise ValueError(f"Invalid tensor dtype {tensor.type}. torch.uint8 type is supported.")
+        raise ValidationError(f"Invalid tensor dtype {tensor.dtype}. torch.uint8 type is supported.")
     packed_tensor = tensor.contiguous()
     packed_tensor = packed_tensor.reshape(-1, 2)
     packed_tensor = torch.bitwise_and(packed_tensor[..., ::2], 15) | packed_tensor[..., 1::2] << 4
@@ -299,7 +300,7 @@ def pack_uint4(tensor: torch.Tensor) -> torch.Tensor:
 @register_operator()
 def unpack_uint4(packed_tensor: torch.Tensor) -> torch.Tensor:
     """
-    Unpacks a tensor where each uint8 element stores two uint4 values back into a tensor with
+    Unpacks a tensor, where each uint8 element stores two uint4 values, back into a tensor with
     individual uint4 values.
 
     :param packed_tensor: A tensor of dtype `torch.uint8` where each element packs two uint4 values.
@@ -316,10 +317,10 @@ def pack_int4(tensor: torch.Tensor) -> torch.Tensor:
     :param tensor: A tensor of dtype `torch.int8` where each element represents an int4 value.
         The tensor should contain values in the range [-8, 7].
     :return: A packed tensor of dtype `torch.uint8` where each element packs two int4 values.
-    :raises ValueError: If the input tensor is not of type `torch.int8`.
+    :raises nncf.errors.ValidationError: If the input tensor is not of type `torch.int8`.
     """
     if tensor.dtype != torch.int8:
-        raise ValueError(f"Invalid tensor dtype {tensor.type}. torch.int8 type is supported.")
+        raise ValidationError(f"Invalid tensor dtype {tensor.dtype}. torch.int8 type is supported.")
     tensor = tensor + 8
     return pack_uint4(tensor.type(torch.uint8))
 
@@ -327,7 +328,7 @@ def pack_int4(tensor: torch.Tensor) -> torch.Tensor:
 @register_operator()
 def unpack_int4(packed_tensor: torch.Tensor) -> torch.Tensor:
     """
-    Unpacks a tensor where each uint8 element stores two int4 values back into a tensor with
+    Unpacks a tensor, where each uint8 element stores two int4 values, back into a tensor with
     individual int4 values.
 
     :param packed_tensor: A tensor of dtype `torch.uint8` where each element packs two int4 values.