Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CORE] Quantized lm-head Framework #4442

Merged
merged 177 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
177 commits
Select commit Hold shift + click to select a range
109291f
skip unknown param g_idx for gptq
Qubitium Apr 26, 2024
319eeed
skip unknown param g_idx for gptq
Qubitium Apr 26, 2024
d17428c
skip unknown param g_idx for gptq
Qubitium Apr 26, 2024
43dc07b
debug
Qubitium Apr 26, 2024
40dfa40
update
Qubitium Apr 26, 2024
b422c4c
update
Qubitium Apr 26, 2024
cf2ef3f
update
Qubitium Apr 26, 2024
75d1aad
test pass linear_method
Qubitium Apr 27, 2024
69963dd
test pass linear_method
Qubitium Apr 27, 2024
3f348ca
pass linear_method
Qubitium Apr 28, 2024
6392659
req update
Qubitium Apr 28, 2024
91b655c
fix issue: can't load quant lm_head weights
Apr 28, 2024
eb4252e
call lm_head.linear_method.apply_weights()
Apr 28, 2024
4b755a1
Merge branch 'main' into autogptq-lm-head
Apr 28, 2024
ce7f348
use quant_config.get_quant_method()
Apr 28, 2024
73c956d
lm_head is VocabParallelEmbedding not LinearBase
Qubitium Apr 28, 2024
cc1ed62
fix error: Overwriting existing tensor attribute: weight_loader
Apr 28, 2024
270169e
lm_head is VocabParallelEmbedding not LinearBase
Qubitium Apr 28, 2024
5e6c796
merge
Qubitium Apr 28, 2024
d667b23
merge
Qubitium Apr 28, 2024
806b8c2
set separate_bias_add=True for unquant linear_method
Qubitium Apr 28, 2024
7bcdf89
cleanup
Qubitium Apr 28, 2024
b02e1ce
try to mimic original logic
Qubitium Apr 28, 2024
2626a45
use linear_method only
Qubitium Apr 28, 2024
b18390f
keep logic as close to main as possible
Qubitium Apr 28, 2024
786a49f
wrong name
Qubitium Apr 28, 2024
1c2c1a5
move linear_method decoding to ParallelLmHead class
Qubitium Apr 28, 2024
561b26c
fix error: UnquantizedLinearMethod.create_weights() missing 1 require…
Apr 28, 2024
2ec6b22
Fix get embedding error
Apr 28, 2024
22297c1
revert req.txt to main
Qubitium Apr 29, 2024
a386d0c
rename
Qubitium Apr 29, 2024
7ac901d
add QUANTIZED bool param to base LinearMethod so we can avoid all the…
Qubitium Apr 29, 2024
ec88f23
wrong state
Qubitium Apr 29, 2024
a299189
pep
Qubitium Apr 29, 2024
200fc77
override QUANTIZED property to True for all quant based linear methods
Qubitium Apr 29, 2024
7ccffee
comment TODO
Qubitium Apr 29, 2024
7247aa7
pass quant_config to base class and init linear_method there
Qubitium Apr 29, 2024
b343695
ruff
Qubitium Apr 29, 2024
1b07243
ruff
Qubitium Apr 29, 2024
49b0fa0
de-ruff
Qubitium Apr 29, 2024
dd5835d
revert newline
Qubitium Apr 29, 2024
ea6819b
fix bad class name QuantizeMethodBase to QuantizableMethodBase
Qubitium Apr 29, 2024
a51b26e
Merge branch 'main' into autogptq-lm-head
Qubitium Apr 29, 2024
7866461
The output_size parameter of create_weights() should be total embeddings
Apr 29, 2024
018b5a5
missed rename
Qubitium Apr 29, 2024
740a70b
isort
Qubitium Apr 29, 2024
b3b6187
clean
Qubitium Apr 29, 2024
ea2f9eb
modify weight_loader()
Apr 29, 2024
aab9566
format
Qubitium Apr 29, 2024
52b6aba
short method name and apply to all models
Qubitium Apr 29, 2024
8a54dca
move method into correct util
Qubitium Apr 29, 2024
8fb68a2
add test_quant_lm_head_layer()
Apr 29, 2024
193d093
pass lm_head and not lm_head.weights
Qubitium Apr 29, 2024
894ee7b
Merge remote-tracking branch 'origin/autogptq-lm-head' into autogptq-…
Qubitium Apr 29, 2024
9773b41
add test_quant_lm_head_false()
Apr 29, 2024
e1a8ec4
do not change opt/lm_head type
Qubitium Apr 29, 2024
950d9da
Merge remote-tracking branch 'origin/autogptq-lm-head' into autogptq-…
Qubitium Apr 29, 2024
e83a0c0
pass test_logits_processor
Qubitium Apr 29, 2024
f3cdff9
enforce_eager for more stable ci test
Qubitium Apr 29, 2024
0c18da0
revert model lm_head type to original
Qubitium Apr 29, 2024
1e03ad8
standardize lm_head for ParallelLMHead
Qubitium Apr 29, 2024
a120090
ruff/import
Qubitium Apr 29, 2024
3229b6c
format
Qubitium Apr 29, 2024
ccd0d54
Revert "standardize lm_head for ParallelLMHead"
Qubitium Apr 29, 2024
0c9b5d4
fix oom in test
Qubitium Apr 29, 2024
6a857b4
fix starcoder2
Qubitium Apr 29, 2024
90c70d2
fix lora
Qubitium Apr 29, 2024
b3f3670
fix lora test
Qubitium Apr 29, 2024
65e9ceb
pass quant_config to all ParallelLMHead
Qubitium Apr 29, 2024
9716863
comment
Qubitium Apr 30, 2024
7b72ef0
expand testing for refactor
robertgshaw2-neuralmagic Apr 30, 2024
83a0087
Merge remote-tracking branch 'upstream/main' into expand-testing
robertgshaw2-neuralmagic Apr 30, 2024
18f8f35
expanded testing
robertgshaw2-neuralmagic Apr 30, 2024
fd1cf0e
raise error if lm_head is quantized but layer is not ParallelLMHead
Qubitium Apr 30, 2024
da0af51
merge lm_false and _true into one test file
Apr 30, 2024
589b50d
formatting
robertgshaw2-neuralmagic Apr 30, 2024
18e6996
format
Qubitium Apr 30, 2024
97c2870
Merge pull request #2 from neuralmagic/expand-testing
Qubitium Apr 30, 2024
e3414df
Revert "raise error if lm_head is quantized but layer is not Parallel…
Qubitium Apr 30, 2024
81cc15c
fix test
Qubitium Apr 30, 2024
232c09d
Skip loading of weights of lm_head layer.
Apr 30, 2024
cd14c7a
test for opt-125M
Apr 30, 2024
1f0ad4c
rename VocabParallelEmbedding to ParallelVocabEmbedding for consistency
Qubitium Apr 30, 2024
83f3823
Merge remote-tracking branch 'origin/autogptq-lm-head' into autogptq-…
Qubitium Apr 30, 2024
d94e6f8
fix opt not loading quantized lm_head
Qubitium Apr 30, 2024
d5a3667
Merge branch 'main' into autogptq-lm-head
robertgshaw2-neuralmagic Apr 30, 2024
8fde7f5
updated test, opt model is failing
robertgshaw2-neuralmagic Apr 30, 2024
2f63a72
fix opt and disable marlin when lm_head quant is enabled. TODO find o…
Qubitium May 1, 2024
4bdf676
Merge remote-tracking branch 'origin/autogptq-lm-head' into autogptq-…
Qubitium May 1, 2024
fd0a758
remove debug prints
Qubitium May 1, 2024
7ef2060
isort
Qubitium May 1, 2024
5b8bba9
add non-quantized opt-125m to test to ensure that quant code did not …
Qubitium May 1, 2024
3a461ec
1. move lm_head_quantized proper into base QuantizationConfig class.
Qubitium May 1, 2024
5d30ab8
format
Qubitium May 1, 2024
09cca74
fix marlin loading of lm_head. remove TODO
Qubitium May 3, 2024
2097e0d
add comments for weight_loader
May 6, 2024
9981620
Replace BaseConfig.is_lm_head_quantized property with BaseConfig.is_l…
Qubitium May 6, 2024
a63c8a3
Merge remote-tracking branch 'origin/autogptq-lm-head' into autogptq-…
Qubitium May 6, 2024
f8f705e
implement BaseConfig.is_lm_head_quantized() for all non-gptq quant co…
Qubitium May 6, 2024
1e8bff6
fix lora logits bias computed twice
Qubitium May 6, 2024
d525139
format
Qubitium May 6, 2024
282a703
merged - removed quantized lm head for tied embeddings
Jun 29, 2024
770b091
format
Jun 29, 2024
51b0064
format passing
Jun 29, 2024
6deae36
fix all parallel-vocab-embedding --> vocab-parallel-embedding
Jun 29, 2024
5828457
quantizable-method-base --> quantize-method-base
Jun 29, 2024
c0551d5
spurious change
Jun 29, 2024
732971e
cleanup AQLM
Jun 29, 2024
584d644
missed nit
Jun 29, 2024
eda6d40
fix up awq
Jun 29, 2024
37ede38
fix up fp8
Jun 29, 2024
df5d5b5
nit
Jun 29, 2024
c0cf2fa
cleanup gptq marlin
Jun 29, 2024
43c67f1
cleanup more
Jun 29, 2024
746fd43
nit
Jun 29, 2024
800d3f7
remove QUANTIZED
Jun 29, 2024
5b0a580
if -- elif
Jun 29, 2024
37795a6
removed
Jun 29, 2024
ac5494e
add kwarg
Jun 29, 2024
e8e8612
cleanup
Jun 29, 2024
5bbdb95
more cleanup
Jun 29, 2024
66b07b1
nit
Jun 29, 2024
d755b4c
use linear_method.create_weights
Jun 29, 2024
4fc9de5
fix spurious changes
Jun 29, 2024
9a699b5
fix typo
Jun 29, 2024
58bbf5d
remove another spurious change
Jun 29, 2024
656ae2d
format
Jun 29, 2024
90ecd98
format fix
Jun 30, 2024
1916467
passing linting
Jun 30, 2024
8981bcd
fix more changes
Jun 30, 2024
9e38196
spurious newline
Jun 30, 2024
b3ad6d5
spurious change in qwen
Jun 30, 2024
db182dd
first end to end run of llama working!
Jun 30, 2024
20a7fbf
format
Jun 30, 2024
90ef065
cleanup
Jun 30, 2024
1990fff
Merge branch 'main' into autogptq-lm-head
Jun 30, 2024
8f3dae8
format
Jun 30, 2024
a8c5aeb
fix opt
Jun 30, 2024
d29ba07
fix gpt2
Jun 30, 2024
47cd399
fix arctic
Jun 30, 2024
92f8eb4
fix gptbigcode
Jun 30, 2024
03db82c
actually fix gptbigcode
Jun 30, 2024
c141c48
commandr
Jun 30, 2024
69f606a
fix deepseekv2
Jun 30, 2024
85e4414
fix falcon
Jun 30, 2024
e04823d
fix gemma2
Jun 30, 2024
1ada5c0
fix jais
Jun 30, 2024
8666f8b
fix mpt
Jun 30, 2024
47b6b64
fix phi3 small
Jun 30, 2024
1e85d83
fix vision models
Jun 30, 2024
a7ebdb2
all models should run
Jun 30, 2024
8a41ff2
formatted
Jun 30, 2024
9f8cf92
cleanup lora test changes
Jun 30, 2024
31ec9e3
cleanup tests
Jun 30, 2024
280b4cc
nit
Jun 30, 2024
f146569
remove llama debug
Jun 30, 2024
4fed894
fix spurious change in qwen2
Jun 30, 2024
29a1fb1
formats again
Jun 30, 2024
57bc23a
re-enabled tests
Jun 30, 2024
fe87469
format
Jun 30, 2024
40fac6a
skip fp16 tests in non cuda automation
Jun 30, 2024
ab2c6a9
fix mlp-speculator
Jun 30, 2024
a0e937a
fix spec decode tests
Jun 30, 2024
28fcecf
passing mlp correctness with precision of float32
Jun 30, 2024
2ab00c0
Merge branch 'upstream-main' into autogptq-lm-head
robertgshaw2-neuralmagic Jun 30, 2024
d8306f7
Merge branch 'upstream-main' into autogptq-lm-head
robertgshaw2-neuralmagic Jun 30, 2024
c0ce169
format
robertgshaw2-neuralmagic Jun 30, 2024
09723d9
Merge branch 'main' into autogptq-lm-head
robertgshaw2-neuralmagic Jul 1, 2024
631f3a5
Address cody's comments
robertgshaw2-neuralmagic Jul 1, 2024
3a83edb
Merge branch 'autogptq-lm-head' of https://github.com/Qubitium/vllm i…
robertgshaw2-neuralmagic Jul 1, 2024
f1f9f48
format
robertgshaw2-neuralmagic Jul 1, 2024
3eae1ad
updated names
robertgshaw2-neuralmagic Jul 1, 2024
ef94034
format
robertgshaw2-neuralmagic Jul 1, 2024
a5a1adf
Merge branch 'main' into autogptq-lm-head
robertgshaw2-neuralmagic Jul 1, 2024
dfafb07
remove too long tests
robertgshaw2-neuralmagic Jul 2, 2024
76633ed
cleanup tests
robertgshaw2-neuralmagic Jul 2, 2024
15b3865
updated
robertgshaw2-neuralmagic Jul 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tests/lora/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,10 @@ def _pretest():

lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=linear.weight,
lm_head=linear,
embedding_bias=None)

original_weight = linear.weight.clone()
original_lm_head = deepcopy(linear)

linear.weight[logits_processor.
org_vocab_size:logits_processor.org_vocab_size +
Expand All @@ -490,7 +490,7 @@ def _pretest():
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = logits_processor._get_logits(hidden_states=input_,
embedding=linear.weight,
lm_head=linear,
embedding_bias=None)
result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
Expand Down Expand Up @@ -519,11 +519,11 @@ def _pretest():

lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
lm_head=original_lm_head,
embedding_bias=None)[:, :vocab_size]
expected_result = logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
lm_head=original_lm_head,
embedding_bias=None)

rtol, atol = TOLERANCES[lora_result.dtype]
Expand Down
45 changes: 45 additions & 0 deletions tests/quantization/test_lm_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Tests whether gptq models with quantized lm_head can be loaded.
Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`.
"""
from typing import Tuple

import pytest
import torch

from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod)
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod

PROMPT = "On the surface of Mars, we found"

MODELS_QUANT = [(
"LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse",
True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)]


@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT)
def test_lm_head(
vllm_runner,
model_lm_head_quant: Tuple[str, bool],
) -> None:
model, lm_head_quantized = model_lm_head_quant
vllm_model = vllm_runner(model, dtype=torch.float16, max_model_len=2048)

lm_head_layer = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model.lm_head)

if lm_head_quantized:
assert isinstance(
lm_head_layer.linear_method,
(GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
else:
assert isinstance(lm_head_layer.linear_method, UnquantizedLinearMethod)

print(
vllm_model.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)[0][1])
del vllm_model
2 changes: 1 addition & 1 deletion tests/spec_decode/e2e/test_mlp_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
MAX_SPEC_TOKENS = 5

# precision
PRECISION = "float16"
PRECISION = "float32"
robertgshaw2-neuralmagic marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def pick_ith(token_ids, logits):
device=device,
pin_memory=is_pin_memory_available())
logits_processor_output = logits_processor(
embedding=None,
lm_head=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)

Expand Down
4 changes: 2 additions & 2 deletions vllm/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,11 +1172,11 @@ def set_mapping(
def _get_logits(
self,
hidden_states: torch.Tensor,
embedding: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
logits = lm_head.linear_method.apply(lm_head, hidden_states)
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
Expand Down
16 changes: 9 additions & 7 deletions vllm/model_executor/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import torch.nn as nn

from vllm.distributed import tensor_model_parallel_gather
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata


Expand Down Expand Up @@ -40,7 +42,7 @@ def __init__(self,

def forward(
self,
embedding: torch.Tensor,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
Expand All @@ -52,8 +54,7 @@ def forward(
sampling_metadata)

# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, embedding, embedding_bias)

logits = self._get_logits(hidden_states, lm_head, embedding_bias)
if logits is not None:
if self.soft_cap is not None:
logits = logits / self.soft_cap
Expand All @@ -68,12 +69,13 @@ def forward(

return logits

def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
def _get_logits(self, hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
Expand Down
9 changes: 9 additions & 0 deletions vllm/model_executor/layers/quantization/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
raise ValueError(f"Cannot find any of {keys} in the model's "
"quantization config.")

@staticmethod
def get_from_keys_or(config: Dict[str, Any], keys: List[str],
default: Any) -> Any:
"""Get a optional value from the model's quantization config."""
try:
return QuantizationConfig.get_from_keys(config, keys)
except ValueError:
return default

@abstractmethod
def get_quant_method(
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
Expand Down
13 changes: 10 additions & 3 deletions vllm/model_executor/layers/quantization/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.utils import set_weight_attrs


Expand All @@ -24,10 +25,12 @@ def __init__(
weight_bits: int,
group_size: int,
desc_act: bool,
lm_head_quantized: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.pack_factor = Fraction(32, self.weight_bits)
if self.weight_bits not in [2, 3, 4, 8]:
raise ValueError(
Expand All @@ -37,7 +40,8 @@ def __init__(
def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
f"desc_act={self.desc_act}),"
f"lm_head_quantized={self.lm_head_quantized}")

@classmethod
def get_name(cls) -> str:
Expand All @@ -61,11 +65,14 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
return cls(weight_bits, group_size, desc_act)
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, lm_head_quantized)

def get_quant_method(
self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
if isinstance(layer, LinearBase):
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQLinearMethod(self)
return None

Expand Down
15 changes: 11 additions & 4 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.utils import get_device_capability_stateless

logger = init_logger(__name__)
Expand Down Expand Up @@ -59,7 +60,7 @@ class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""

def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool) -> None:
is_sym: bool, lm_head_quantized: bool) -> None:
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
Expand All @@ -69,6 +70,7 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.lm_head_quantized = lm_head_quantized

# Verify
if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
Expand Down Expand Up @@ -96,7 +98,8 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized})")

@classmethod
def get_name(cls) -> str:
Expand All @@ -120,7 +123,10 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
return cls(weight_bits, group_size, desc_act, is_sym)
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym,
lm_head_quantized)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
Expand All @@ -145,7 +151,8 @@ def override_quantization_method(cls, hf_quant_cfg,
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
if isinstance(layer, LinearBase):
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQMarlinLinearMethod(self)
return None

Expand Down
13 changes: 10 additions & 3 deletions vllm/model_executor/layers/quantization/marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)
Expand All @@ -22,9 +23,11 @@ class MarlinConfig(QuantizationConfig):
def __init__(
self,
group_size: int,
lm_head_quantized: bool,
) -> None:
# Group size for the quantization.
self.group_size = group_size
self.lm_head_quantized = lm_head_quantized
if self.group_size != 128 and self.group_size != -1:
raise ValueError(
"Currently, only group size 128 and -1 (channelwise) "
Expand All @@ -51,7 +54,8 @@ def __init__(
self.perm_len = 1024

def __repr__(self) -> str:
return f"MarlinConfig(group_size={self.group_size})"
return (f"MarlinConfig(group_size={self.group_size}, "
f"lm_head_quantized={self.lm_head_quantized})")

@classmethod
def get_name(cls) -> str:
Expand All @@ -73,7 +77,9 @@ def get_config_filenames(cls) -> List[str]:
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
return cls(group_size)
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(group_size, lm_head_quantized)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
Expand All @@ -96,7 +102,8 @@ def override_quantization_method(cls, hf_quant_cfg,

def get_quant_method(
self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
if isinstance(layer, LinearBase):
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return MarlinLinearMethod(self)
return None

Expand Down
Loading
Loading