From 7d4e1b85e78574acb5c682fa9fe1d3dfa5f092d7 Mon Sep 17 00:00:00 2001
From: Qubitium <417764+Qubitium@users.noreply.github.com>
Date: Tue, 2 Apr 2024 07:32:01 +0800
Subject: [PATCH] [Misc] Add support for new autogptq checkpoint_format (#3689)

Co-authored-by: Robert Shaw
---
 .../test_autogptq_marlin_configs.py | 68 +++++++++++++++++++
 vllm/config.py                      | 28 ++++----
 2 files changed, 83 insertions(+), 13 deletions(-)
 create mode 100644 tests/quantization/test_autogptq_marlin_configs.py

diff --git a/tests/quantization/test_autogptq_marlin_configs.py b/tests/quantization/test_autogptq_marlin_configs.py
new file mode 100644
index 0000000000000..cd64622e2226f
--- /dev/null
+++ b/tests/quantization/test_autogptq_marlin_configs.py
@@ -0,0 +1,68 @@
+"""Tests whether Marlin models can be loaded from the autogptq config.
+
+Run `pytest tests/quantization/test_autogptq_marlin_configs.py --forked`.
+"""
+
+from dataclasses import dataclass
+
+import pytest
+
+from vllm.config import ModelConfig
+
+
+@dataclass
+class ModelPair:
+    model_marlin: str
+    model_gptq: str
+
+
+# Model Id // Expected Kernel
+MODELS_QUANT_TYPE = [
+    # compat: autogptq <=0.7.1 is_marlin_format: bool
+    ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin"),
+    ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq"),
+    # compat: autogptq >=0.8.0 use checkpoint_format: str
+    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin"),
+    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq")
+]
+
+
+@pytest.mark.parametrize("model_quant_type", MODELS_QUANT_TYPE)
+def test_auto_gptq(model_quant_type: str, ) -> None:
+    model_path, quant_type = model_quant_type
+
+    model_config_no_quant_arg = ModelConfig(
+        model_path,
+        model_path,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        download_dir=None,
+        load_format="dummy",
+        seed=0,
+        dtype="float16",
+        revision=None,
+        quantization=None  # case 1
+    )
+
+    model_config_quant_arg = ModelConfig(
+        model_path,
+        model_path,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        download_dir=None,
+        load_format="dummy",
+        seed=0,
+        dtype="float16",
+        revision=None,
+        quantization="gptq"  # case 2
+    )
+
+    assert model_config_no_quant_arg.quantization == quant_type, (
+        f"Expected quant_type == {quant_type} for {model_path}, "
+        f"but found {model_config_no_quant_arg.quantization} "
+        "for no --quantization None case")
+
+    assert model_config_quant_arg.quantization == quant_type, (
+        f"Expected quant_type == {quant_type} for {model_path}, "
+        f"but found {model_config_quant_arg.quantization} "
+        "for --quantization gptq case")
diff --git a/vllm/config.py b/vllm/config.py
index 3da9abb13ad9a..903829d8b176d 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -171,26 +171,28 @@ def _verify_quantization(self) -> None:
             self.quantization = self.quantization.lower()
 
         # Parse quantization method from the HF model config, if available.
-        hf_quant_config = getattr(self.hf_config, "quantization_config", None)
-        if hf_quant_config is not None:
-            hf_quant_method = str(hf_quant_config["quant_method"]).lower()
-
-            # If the GPTQ model is serialized in marlin format, use marlin.
-            if (hf_quant_method == "gptq"
-                    and "is_marlin_format" in hf_quant_config
-                    and hf_quant_config["is_marlin_format"]):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+            # compat: autogptq >=0.8.0 use checkpoint_format: str
+            # compat: autogptq <=0.7.1 is_marlin_format: bool
+            is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
+                                or quant_cfg.get("is_marlin_format", False))
+
+            # Use marlin if the GPTQ model is serialized in marlin format.
+            if quant_method == "gptq" and is_format_marlin:
                 logger.info("The model is serialized in Marlin format. "
                             "Using Marlin kernel.")
-                hf_quant_method = "marlin"
+                quant_method = "marlin"
                 if self.quantization == "gptq":
-                    self.quantization = hf_quant_method
+                    self.quantization = quant_method
 
             if self.quantization is None:
-                self.quantization = hf_quant_method
-            elif self.quantization != hf_quant_method:
+                self.quantization = quant_method
+            elif self.quantization != quant_method:
                 raise ValueError(
                     "Quantization method specified in the model config "
-                    f"({hf_quant_method}) does not match the quantization "
+                    f"({quant_method}) does not match the quantization "
                     f"method specified in the `quantization` argument "
                     f"({self.quantization}).")
 
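
Note (reviewer aid, not part of the patch): below is a minimal standalone sketch of the compatibility check this change adds to ModelConfig._verify_quantization. The helper name detect_quant_method and the example config dicts are illustrative assumptions only, not vLLM or autogptq APIs.

# Standalone sketch, assuming a plain dict shaped like the HF quantization_config;
# the real logic lives in ModelConfig._verify_quantization in vllm/config.py.

def detect_quant_method(quant_cfg: dict) -> str:
    """Return the kernel name implied by a quantization_config-like dict."""
    quant_method = quant_cfg.get("quant_method", "").lower()
    # compat: autogptq >=0.8.0 writes checkpoint_format: str
    # compat: autogptq <=0.7.1 writes is_marlin_format: bool
    is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
                        or quant_cfg.get("is_marlin_format", False))
    # A GPTQ checkpoint serialized in Marlin format should use the Marlin kernel.
    if quant_method == "gptq" and is_format_marlin:
        return "marlin"
    return quant_method


# Example configs (illustrative, not copied from any real checkpoint):
assert detect_quant_method({"quant_method": "gptq",
                            "is_marlin_format": True}) == "marlin"       # autogptq <=0.7.1
assert detect_quant_method({"quant_method": "gptq",
                            "checkpoint_format": "marlin"}) == "marlin"  # autogptq >=0.8.0
assert detect_quant_method({"quant_method": "gptq"}) == "gptq"           # plain GPTQ checkpoint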