From a5a6b0bd3eebb0efa55f467f2d71dc22e4fb7ff9 Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Date: Thu, 30 May 2024 05:58:37 -0700
Subject: [PATCH] [Bugfix] Automatically Detect SparseML models (#5119)

---
 vllm/config.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 4b256d00a32df..4d05b4ea36d5c 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -156,6 +156,17 @@ def _verify_embedding_mode(self) -> None:
         self.embedding_mode = any(
             ModelRegistry.is_embedding_model(arch) for arch in architectures)
 
+    def _parse_quant_hf_config(self):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is None:
+            # SparseML uses a "compression_config" with a "quantization_config".
+            compression_cfg = getattr(self.hf_config, "compression_config",
+                                      None)
+            if compression_cfg is not None:
+                quant_cfg = compression_cfg.get("quantization_config", None)
+
+        return quant_cfg
+
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = ["gptq", "squeezellm"]
@@ -163,12 +174,13 @@ def _verify_quantization(self) -> None:
             self.quantization = self.quantization.lower()
 
         # Parse quantization method from the HF model config, if available.
-        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        quant_cfg = self._parse_quant_hf_config()
+
         if quant_cfg is not None:
             quant_method = quant_cfg.get("quant_method", "").lower()
 
             # Detect which checkpoint is it
-            for name, method in QUANTIZATION_METHODS.items():
+            for _, method in QUANTIZATION_METHODS.items():
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization)
                 if quantization_override:
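
For reference, below is a minimal standalone sketch of the fallback logic this patch adds. It is not part of the patch itself: the SimpleNamespace objects are hypothetical stand-ins for the transformers HF config, and the "sparseml" quant_method value is purely illustrative.

    # Sketch of the fallback that _parse_quant_hf_config introduces:
    # prefer "quantization_config", then fall back to the dict that
    # SparseML checkpoints nest under "compression_config".
    from types import SimpleNamespace


    def parse_quant_hf_config(hf_config):
        quant_cfg = getattr(hf_config, "quantization_config", None)
        if quant_cfg is None:
            compression_cfg = getattr(hf_config, "compression_config", None)
            if compression_cfg is not None:
                quant_cfg = compression_cfg.get("quantization_config", None)
        return quant_cfg


    # A standard quantized checkpoint exposes quantization_config directly.
    gptq_like = SimpleNamespace(quantization_config={"quant_method": "gptq"})
    assert parse_quant_hf_config(gptq_like) == {"quant_method": "gptq"}

    # A SparseML-style checkpoint nests it inside compression_config;
    # before this patch it would have been missed (quant_cfg stayed None).
    sparseml_like = SimpleNamespace(
        compression_config={"quantization_config": {"quant_method": "sparseml"}})
    assert parse_quant_hf_config(sparseml_like) == {"quant_method": "sparseml"}

    # A plain, unquantized model config yields None, as before.
    assert parse_quant_hf_config(SimpleNamespace()) is None

Once a non-None quant_cfg is returned, _verify_quantization proceeds unchanged: each registered method's override_quantization_method gets a chance to claim the checkpoint, which is why SparseML models are now detected automatically.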