Refine HQQ UTs (#1888)
Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored Jul 2, 2024
1 parent 5592acc commit 63b2912
Showing 2 changed files with 23 additions and 137 deletions.
130 changes: 0 additions & 130 deletions test/3x/torch/quantization/weight_only/hqq/test_hqq_cuda.py

This file was deleted.

@@ -6,6 +6,7 @@
 import transformers
 from transformers import AutoModelForCausalLM
 
+from neural_compressor.common.utils import logger
 from neural_compressor.torch.algorithms.weight_only.hqq.config import HQQModuleConfig, QTensorConfig, hqq_global_option
 from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear
 from neural_compressor.torch.quantization import HQQConfig, convert, get_default_hqq_config, prepare, quantize
@@ -14,7 +15,9 @@
 device = accelerator.current_device_name()
 
 
-def _common_cpu_test(nbits=4, group_size=64, quant_zero=True, quant_scale=False, scale_quant_group_size=128):
+def _common_hqq_test(
+    nbits=4, group_size=64, quant_zero=True, quant_scale=False, scale_quant_group_size=128, device=None
+):
     # Parse config
     weight_qconfig = QTensorConfig(
         nbits=nbits, channel_wise=True, group_size=group_size, optimize=True, round_zero=True if nbits == 4 else False
@@ -26,15 +29,14 @@ def _common_cpu_test(nbits=4, group_size=64, quant_zero=True, quant_scale=False,
     if quant_scale:
         scale_qconfig = QTensorConfig(nbits=8, channel_wise=True, group_size=scale_quant_group_size, optimize=False)
     hqq_quant_config = HQQModuleConfig(weight=weight_qconfig, scale=scale_qconfig, zero=zero_qconfig)
-    device = "cpu"
 
     # Create HQQ Linear
     bs = 4
     in_features = 64
     out_features = 128
     float_linear = torch.nn.Linear(in_features=in_features, out_features=out_features)
     if hqq_global_option.use_half:
-        print(f"hqq_global_option use half: {hqq_global_option.use_half}")
+        logger.info(f"hqq_global_option use half: {hqq_global_option.use_half}")
         float_linear = float_linear.half()
     float_linear.to(device)
     float_linear_copy = deepcopy(float_linear)
@@ -54,7 +56,7 @@ def _common_cpu_test(nbits=4, group_size=64, quant_zero=True, quant_scale=False,
     del float_output, hqq_output, hqq_output_2
 
 
-class TestHQQCPU:
+class TestHQQ:
 
     @classmethod
     def setup_class(cls):
@@ -137,6 +139,7 @@ def test_quant_lm_head(self, force_use_cpu, force_not_half):
             id(model.model.decoder.embed_tokens.weight) == lm_head_id
         ), "The tied lm_head weight is not deep copied, please check!"
 
+    @pytest.mark.parametrize("device_name", ["cuda", "cpu"])
     @pytest.mark.parametrize(
         "nbits, group_size, quant_zero, quant_scale, scale_quant_group_size",
         [
@@ -155,13 +158,26 @@ def test_quant_lm_head(self, force_use_cpu, force_not_half):
             (4, -1, False, True, 64),
         ],
     )
-    def test_hqq_module_cpu(
-        self, force_use_cpu, force_not_half, nbits, group_size, quant_zero, quant_scale, scale_quant_group_size
+    def test_hqq_module(
+        self,
+        nbits,
+        group_size,
+        quant_zero,
+        quant_scale,
+        scale_quant_group_size,
+        device_name,
     ):
-        _common_cpu_test(
+        if device_name == "cuda" and not torch.cuda.is_available():
+            pytest.skip("Skipping CUDA test because cuda is not available")
+        if device_name == "cpu":
+            os.environ["FORCE_DEVICE"] = "cpu"
+            hqq_global_option.use_half = False
+
+        _common_hqq_test(
             nbits=nbits,
             group_size=group_size,
             quant_zero=quant_zero,
             quant_scale=quant_scale,
             scale_quant_group_size=scale_quant_group_size,
+            device=torch.device(device_name),
         )
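
For context, a minimal standalone sketch (not part of this commit) of the device-parametrization pattern that the new test_hqq_module follows; the test name test_linear_roundtrip, the plain torch.nn.Linear, and the tensor shapes are illustrative assumptions, and only pytest and torch are required:

import os

import pytest
import torch


@pytest.mark.parametrize("device_name", ["cuda", "cpu"])
def test_linear_roundtrip(device_name):
    # Skip the CUDA case on machines without a GPU, mirroring the guard added in this commit.
    if device_name == "cuda" and not torch.cuda.is_available():
        pytest.skip("Skipping CUDA test because cuda is not available")
    if device_name == "cpu":
        # Illustrative only: the HQQ tests pin the backend through this environment variable.
        os.environ["FORCE_DEVICE"] = "cpu"

    device = torch.device(device_name)
    linear = torch.nn.Linear(in_features=64, out_features=128).to(device)
    x = torch.randn(4, 64, device=device)
    assert linear(x).shape == (4, 128)

Folding the separate CUDA test file into a single parametrized test keeps one code path and lets pytest report the CUDA variants as skipped on CPU-only machines instead of omitting them entirely.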
