added test_pack_uint4 and test_pack_int4
alexsu52 committed Oct 24, 2024
1 parent e38788c commit 2632e99
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/torch/ptq/test_weights_compression.py
@@ -22,6 +22,10 @@
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor
from nncf.torch.quantization.quantize_functions import pack_int4
from nncf.torch.quantization.quantize_functions import pack_uint4
from nncf.torch.quantization.quantize_functions import unpack_int4
from nncf.torch.quantization.quantize_functions import unpack_uint4

DATA_BASED_SENSITIVITY_METRICS = (
SensitivityMetric.HESSIAN_INPUT_ACTIVATION,
@@ -311,3 +315,21 @@ def test_model_devices_and_precisions(use_cuda, dtype):
assert compressed_model.state_dict()["_nncf.external_op.weights_decompressor_w._scale"].dtype == torch.float16
# Result should be in the precision of the model
assert result.dtype == dtype


def test_pack_uint4():
    # torch.randint's upper bound is exclusive, so use 16 to cover the full
    # uint4 range [0, 15].
    w_uint8 = torch.randint(0, 16, (4, 4), dtype=torch.uint8)
    packed_w = pack_uint4(w_uint8)
    assert packed_w.dtype == torch.uint8
    # Two 4-bit values are packed into each uint8 byte.
    assert packed_w.numel() * 2 == w_uint8.numel()
    unpacked_w = unpack_uint4(packed_w).reshape(w_uint8.shape)
    assert torch.all(unpacked_w == w_uint8)


def test_pack_int4():
    # torch.randint's upper bound is exclusive, so use 8 to cover the full
    # int4 range [-8, 7].
    w_int8 = torch.randint(-8, 8, (4, 4), dtype=torch.int8)
    packed_w = pack_int4(w_int8)
    assert packed_w.dtype == torch.uint8
    # Two 4-bit values are packed into each uint8 byte.
    assert packed_w.numel() * 2 == w_int8.numel()
    unpacked_w = unpack_int4(packed_w).reshape(w_int8.shape)
    assert torch.all(unpacked_w == w_int8)
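For context on what these tests assert: two 4-bit values share a single uint8 byte, which is why packed_w.numel() * 2 must equal the element count of the source tensor. Below is a minimal illustrative sketch of that packing convention (an assumption for clarity only; the nibble order and signed encoding used by NNCF's pack_uint4/pack_int4 may differ):

import torch

# Hypothetical helpers illustrating two-nibbles-per-byte packing; they are
# not NNCF's actual implementation.
def pack_uint4_sketch(w: torch.Tensor) -> torch.Tensor:
    # Assumes a uint8 tensor with values in [0, 15] and an even element count.
    flat = w.flatten()
    # Even-indexed values fill the low nibble, odd-indexed the high nibble.
    return (flat[::2] | (flat[1::2] << 4)).to(torch.uint8)

def unpack_uint4_sketch(packed: torch.Tensor) -> torch.Tensor:
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    # Interleave back to the original order: low0, high0, low1, high1, ...
    return torch.stack([low, high], dim=-1).flatten()

def pack_int4_sketch(w: torch.Tensor) -> torch.Tensor:
    # Shift the signed int4 range [-8, 7] into [0, 15], then pack as uint4.
    return pack_uint4_sketch((w.to(torch.int16) + 8).to(torch.uint8))

def unpack_int4_sketch(packed: torch.Tensor) -> torch.Tensor:
    # Undo the +8 offset to recover signed values.
    return unpack_uint4_sketch(packed).to(torch.int8) - 8

Round-tripping a tensor through these sketches, e.g. unpack_uint4_sketch(pack_uint4_sketch(t)).reshape(t.shape), reproduces t exactly, mirroring the reshape-and-compare pattern used in both tests.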
