-
-
Notifications
You must be signed in to change notification settings - Fork 4.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CORE] Quantized lm-head Framework #4442
Changes from 158 commits
109291f
319eeed
d17428c
43dc07b
40dfa40
b422c4c
cf2ef3f
75d1aad
69963dd
3f348ca
6392659
91b655c
eb4252e
4b755a1
ce7f348
73c956d
cc1ed62
270169e
5e6c796
d667b23
806b8c2
7bcdf89
b02e1ce
2626a45
b18390f
786a49f
1c2c1a5
561b26c
2ec6b22
22297c1
a386d0c
7ac901d
ec88f23
a299189
200fc77
7ccffee
7247aa7
b343695
1b07243
49b0fa0
dd5835d
ea6819b
a51b26e
7866461
018b5a5
740a70b
b3b6187
ea2f9eb
aab9566
52b6aba
8a54dca
8fb68a2
193d093
894ee7b
9773b41
e1a8ec4
950d9da
e83a0c0
f3cdff9
0c18da0
1e03ad8
a120090
3229b6c
ccd0d54
0c9b5d4
6a857b4
90c70d2
b3f3670
65e9ceb
9716863
7b72ef0
83a0087
18f8f35
fd1cf0e
da0af51
589b50d
18e6996
97c2870
e3414df
81cc15c
232c09d
cd14c7a
1f0ad4c
83f3823
d94e6f8
d5a3667
8fde7f5
2f63a72
4bdf676
fd0a758
7ef2060
5b8bba9
3a461ec
5d30ab8
09cca74
2097e0d
9981620
a63c8a3
f8f705e
1e8bff6
d525139
282a703
770b091
51b0064
6deae36
5828457
c0551d5
732971e
584d644
eda6d40
37ede38
df5d5b5
c0cf2fa
43c67f1
746fd43
800d3f7
5b0a580
37795a6
ac5494e
e8e8612
5bbdb95
66b07b1
d755b4c
4fc9de5
9a699b5
58bbf5d
656ae2d
90ecd98
1916467
8981bcd
9e38196
b3ad6d5
db182dd
20a7fbf
90ef065
1990fff
8f3dae8
a8c5aeb
d29ba07
47cd399
92f8eb4
03db82c
c141c48
69f606a
85e4414
e04823d
1ada5c0
8666f8b
47b6b64
1e85d83
a7ebdb2
8a41ff2
9f8cf92
31ec9e3
280b4cc
f146569
4fed894
29a1fb1
57bc23a
fe87469
40fac6a
ab2c6a9
a0e937a
28fcecf
2ab00c0
d8306f7
c0ce169
09723d9
631f3a5
3a83edb
f1f9f48
3eae1ad
ef94034
a5a1adf
dfafb07
76633ed
15b3865
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
"""Compares the outputs of hf vs vllm for medium sized models. | ||
|
||
There is not bitwise correctness for fp16 inference. | ||
As a result, in this test, we just confirm that the top selected tokens of the | ||
Marlin/GPTQ models are in the top 3 selections of each other. | ||
|
||
Run `pytest tests/models/test_models_medium_logprobs.py`. | ||
""" | ||
import pytest | ||
|
||
from tests.models.utils import check_logprobs_close | ||
|
||
MAX_MODEL_LEN = 1024 | ||
|
||
MODELS = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How long does it take to run all models listed here? Can some of them be removed to reduce the CI time? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can just remove them. I just wanted to prove the accuracy was right |
||
# # arctic - skip: size in automation | ||
# "baichuan-inc/Baichuan2-7B-Chat", | ||
# "bigscience/bloom-560m", | ||
# # "THUDM/chatglm3-6b", skip: hf implementation broken | ||
# # commandr - skip: size in automation size | ||
# # dbrx - skip: size in automation | ||
# "Deci/DeciLM-7B-instruct", | ||
# # deepseek_v2 - skip: size in automation | ||
# "deepseek-ai/deepseek-coder-1.3b-instruct", | ||
# "tiiuae/falcon-rw-1b", | ||
# "google/gemma-1.1-2b-it", | ||
# # "google/gemma-2-9b-it", skip: not supported in transformers yet | ||
# "bigcode/tiny_starcoder_py", | ||
# "EleutherAI/gpt-j-6b", | ||
# "EleutherAI/pythia-410m", | ||
# "gpt2", | ||
# "internlm/internlm2-chat-7b", | ||
# # jais - skip: size in automation | ||
# "TinyLlama/TinyLlama-1.1B-Chat-v1.0", | ||
# "openbmb/MiniCPM-2B-128k", | ||
# # mixtral - skip: size in automation | ||
# "mosaicml/mpt-7b-instruct", | ||
# # "allenai/OLMo-1B", # skip: broken in transformers | ||
"facebook/opt-125m", | ||
# orion - skip: size in automation | ||
"microsoft/phi-2", | ||
"microsoft/Phi-3-small-8k-instruct", | ||
"Qwen/Qwen-1_8B", | ||
"Qwen/Qwen1.5-1.8B", | ||
"Qwen/Qwen2-0.5B-Instruct" | ||
# qwen2 moe - skip: size in automation | ||
"stabilityai/stablelm-2-1_6b-chat", | ||
"bigcode/starcoder2-3b", | ||
"xverse/XVERSE-7B", | ||
] | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["half"]) | ||
@pytest.mark.parametrize("max_tokens", [32]) | ||
@pytest.mark.parametrize("num_logprobs", [5]) | ||
def test_models( | ||
vllm_runner, | ||
hf_runner, | ||
example_prompts, | ||
model, | ||
dtype: str, | ||
max_tokens: int, | ||
num_logprobs: int, | ||
) -> None: | ||
# Run HF. | ||
hf_model = hf_runner(model_name=model, dtype=dtype) | ||
hf_outputs = hf_model.generate_greedy_logprobs_limit( | ||
example_prompts, max_tokens, num_logprobs) | ||
del hf_model | ||
|
||
# Run vLLM. | ||
vllm_model = vllm_runner(model_name=model, | ||
dtype=dtype, | ||
max_model_len=MAX_MODEL_LEN, | ||
tensor_parallel_size=1) | ||
vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts, | ||
max_tokens, | ||
num_logprobs) | ||
del vllm_model | ||
|
||
check_logprobs_close( | ||
outputs_0_lst=hf_outputs, | ||
outputs_1_lst=vllm_outputs, | ||
name_0="hf", | ||
name_1="vllm", | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
"""Tests whether gptq models with quantized lm_head can be loaded. | ||
Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`. | ||
""" | ||
from typing import Tuple | ||
|
||
import pytest | ||
import torch | ||
|
||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod | ||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod | ||
from vllm.model_executor.layers.quantization.gptq_marlin import ( | ||
GPTQMarlinLinearMethod) | ||
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod | ||
|
||
PROMPT = "On the surface of Mars, we found" | ||
|
||
MODELS_QUANT = [( | ||
"LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse", | ||
True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False), | ||
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)] | ||
|
||
|
||
@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT) | ||
def test_lm_head( | ||
vllm_runner, | ||
model_lm_head_quant: Tuple[str, bool], | ||
) -> None: | ||
model, lm_head_quantized = model_lm_head_quant | ||
vllm_model = vllm_runner(model, dtype=torch.float16, max_model_len=2048) | ||
|
||
lm_head_layer = (vllm_model.model.llm_engine.model_executor.driver_worker. | ||
model_runner.model.lm_head) | ||
|
||
if lm_head_quantized: | ||
assert isinstance( | ||
lm_head_layer.linear_method, | ||
(GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod)) | ||
else: | ||
assert isinstance(lm_head_layer.linear_method, UnquantizedLinearMethod) | ||
|
||
print( | ||
vllm_model.generate_greedy(prompts=["Hello my name is"], | ||
max_tokens=10)[0][1]) | ||
del vllm_model |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why we have to re-download Falcon?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
trust_remote_code
did not work for falcon, not sure why