
Commit fbab69a
[Bugfix] Automatically Detect SparseML models (vllm-project#5119)
robertgshaw2-neuralmagic committed Jul 14, 2024
1 parent ed71c6b commit fbab69a
Showing 1 changed file with 14 additions and 2 deletions.
vllm/config.py (14 additions, 2 deletions)
@@ -188,19 +188,31 @@ def _verify_embedding_mode(self) -> None:
         self.embedding_mode = any(
             ModelRegistry.is_embedding_model(arch) for arch in architectures)
 
+    def _parse_quant_hf_config(self):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is None:
+            # SparseML uses a "compression_config" with a "quantization_config".
+            compression_cfg = getattr(self.hf_config, "compression_config",
+                                      None)
+            if compression_cfg is not None:
+                quant_cfg = compression_cfg.get("quantization_config", None)
+
+        return quant_cfg
+
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = ["gptq", "squeezellm"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
 
         # Parse quantization method from the HF model config, if available.
-        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        quant_cfg = self._parse_quant_hf_config()
 
         if quant_cfg is not None:
             quant_method = quant_cfg.get("quant_method", "").lower()
 
             # Detect which checkpoint it is
-            for name, method in QUANTIZATION_METHODS.items():
+            for _, method in QUANTIZATION_METHODS.items():
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization)
                 if quantization_override:
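For context, a minimal sketch (not part of the commit) of the lookup order the new helper implements, assuming a SparseML-style checkpoint that nests its quantization settings under "compression_config" rather than exposing a top-level "quantization_config". The stand-in hf_config object and its sample values ("gptq", 4 bits) are hypothetical, for illustration only:

from types import SimpleNamespace

# Hypothetical SparseML-style HF config: no top-level "quantization_config",
# only a "compression_config" dict that nests one (sample values are made up).
hf_config = SimpleNamespace(
    compression_config={"quantization_config": {"quant_method": "gptq",
                                                "bits": 4}})

# Mirrors the lookup order of the new _parse_quant_hf_config helper.
quant_cfg = getattr(hf_config, "quantization_config", None)
if quant_cfg is None:
    compression_cfg = getattr(hf_config, "compression_config", None)
    if compression_cfg is not None:
        quant_cfg = compression_cfg.get("quantization_config", None)

print(quant_cfg)  # {'quant_method': 'gptq', 'bits': 4}

For a conventional checkpoint with a top-level quantization_config, the first getattr succeeds and the fallback never runs, so existing models are unaffected; this is why the change can detect SparseML models automatically without a new flag.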
