From 9c772ac88883dd35166b58cc8cf10cffa3ca7844 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 6 May 2024 11:33:52 +0200 Subject: [PATCH] Quantization / HQQ: Fix HQQ tests on our runner (#30668) Update test_hqq.py --- tests/quantization/hqq/test_hqq.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py index e4e01f86496388..45c64676a7e42a 100755 --- a/tests/quantization/hqq/test_hqq.py +++ b/tests/quantization/hqq/test_hqq.py @@ -35,7 +35,7 @@ class HQQLLMRunner: - def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir): + def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir=None): self.model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=compute_dtype, @@ -118,7 +118,7 @@ def test_fp16_quantized_model(self): check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) check_forward(self, hqq_runner.model) - def test_bfp16_quantized_model_with_offloading(self): + def test_f16_quantized_model_with_offloading(self): """ Simple LLM model testing bfp16 with meta-data offloading """ @@ -137,7 +137,7 @@ def test_bfp16_quantized_model_with_offloading(self): ) hqq_runner = HQQLLMRunner( - model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.bfloat16, device=torch_device + model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device ) check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)