Quantization / HQQ: Fix HQQ tests on our runner (#30668)
Update test_hqq.py
younesbelkada authored May 6, 2024
1 parent a45c514 commit 9c772ac
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/quantization/hqq/test_hqq.py
@@ -35,7 +35,7 @@
 
 
 class HQQLLMRunner:
-    def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir):
+    def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir=None):
         self.model = AutoModelForCausalLM.from_pretrained(
             model_id,
             torch_dtype=compute_dtype,
@@ -118,7 +118,7 @@ def test_fp16_quantized_model(self):
         check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
         check_forward(self, hqq_runner.model)
 
-    def test_bfp16_quantized_model_with_offloading(self):
+    def test_f16_quantized_model_with_offloading(self):
         """
         Simple LLM model testing bfp16 with meta-data offloading
         """
@@ -137,7 +137,7 @@ def test_bfp16_quantized_model_with_offloading(self):
         )
 
         hqq_runner = HQQLLMRunner(
-            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.bfloat16, device=torch_device
+            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
         )
 
         check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
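In summary, the commit makes three small changes to tests/quantization/hqq/test_hqq.py: cache_dir gains a default of None in HQQLLMRunner.__init__, the offloading test is renamed from test_bfp16_quantized_model_with_offloading to test_f16_quantized_model_with_offloading, and its compute dtype switches from torch.bfloat16 to torch.float16.

For context, below is a minimal sketch of the pattern the fixed test exercises: loading a causal LM with an HQQ quantization config at float16 and running a forward pass. The model id and HqqConfig arguments are illustrative assumptions, not values taken from this diff.

# Minimal sketch of the HQQ load-and-check pattern the test exercises.
# MODEL_ID and the HqqConfig arguments are assumptions, not from this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

MODEL_ID = "meta-llama/Llama-2-7b-hf"             # assumed model id
quant_config = HqqConfig(nbits=4, group_size=64)  # assumed quantization settings

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,      # the dtype the test now uses
    device_map="cuda",
    quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Quick forward pass to confirm the quantized model still runs end to end.
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model(**inputs)
print(output.logits.shape)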
