TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models #1622

Merged: 53 commits, Nov 16, 2023
Commits (53)
6541618  Create linear method (zhuohan123, Nov 3, 2023)
a97ede8  Support llama with the new quantization scheme (zhuohan123, Nov 3, 2023)
4671286  make awq work (zhuohan123, Nov 3, 2023)
4579d67  Fix squeezellm (zhuohan123, Nov 3, 2023)
4406447  Remove unused codes (zhuohan123, Nov 3, 2023)
5a535e3  Fix mistral (zhuohan123, Nov 3, 2023)
14e66f8  Fix format (zhuohan123, Nov 3, 2023)
f464375  New weight loading method, working for llama (zhuohan123, Nov 8, 2023)
a5852ef  Fix awq loading (zhuohan123, Nov 8, 2023)
7bf933f  Fix squeeze llm (zhuohan123, Nov 8, 2023)
8af8b60  fix quantization (zhuohan123, Nov 8, 2023)
686dafb  new weight loader (zhuohan123, Nov 9, 2023)
e474020  Fix vocab loading (zhuohan123, Nov 9, 2023)
d107613  clean up llama loader (zhuohan123, Nov 9, 2023)
d4aa8c9  fix awq (zhuohan123, Nov 9, 2023)
f48381b  wip fix squeezellm (zhuohan123, Nov 9, 2023)
c5a9f9c  fix squeeze llm (zhuohan123, Nov 9, 2023)
92155da  fix weight loader for embedding (zhuohan123, Nov 9, 2023)
e528dbc  fix (zhuohan123, Nov 9, 2023)
772ab72  support mistral (zhuohan123, Nov 9, 2023)
0a08e66  fix (zhuohan123, Nov 9, 2023)
7d7aa4b  Fix aqulia (zhuohan123, Nov 9, 2023)
1df5d6b  fix vocab loader (zhuohan123, Nov 9, 2023)
93685f4  fix baichuan (zhuohan123, Nov 9, 2023)
5f5ea90  fix bloom (zhuohan123, Nov 9, 2023)
31af3ea  fix qwen (zhuohan123, Nov 10, 2023)
68f5a3f  fix qwen (zhuohan123, Nov 10, 2023)
4f68d07  fix opt (zhuohan123, Nov 10, 2023)
23099e2  fix mpt (zhuohan123, Nov 10, 2023)
d7d108d  fix internlm (zhuohan123, Nov 10, 2023)
ed44156  fix gpt2 (zhuohan123, Nov 10, 2023)
a75dea1  fix gpt neox (zhuohan123, Nov 10, 2023)
1f6ca33  fix gptj (zhuohan123, Nov 10, 2023)
b118a2f  fix falcon (zhuohan123, Nov 10, 2023)
fb595c7  clean up (zhuohan123, Nov 10, 2023)
d5ffe88  Fix GPT Bigcode (zhuohan123, Nov 11, 2023)
036bee8  Merge branch 'main' into refactor-quantization (zhuohan123, Nov 11, 2023)
7acf443  Fix chatglm and yi models (zhuohan123, Nov 11, 2023)
c33e0f0  format (zhuohan123, Nov 11, 2023)
63af93c  Simplify code logic (zhuohan123, Nov 11, 2023)
82c76b1  Simplify code (zhuohan123, Nov 11, 2023)
f53469b  fix (zhuohan123, Nov 11, 2023)
d4c0798  Add comment for linear.py (zhuohan123, Nov 11, 2023)
f0e7f44  Add comments (zhuohan123, Nov 11, 2023)
247252c  code cleanup (zhuohan123, Nov 11, 2023)
dfb4a81  Add comment (zhuohan123, Nov 11, 2023)
79a6a9a  Merge branch 'main' into refactor-quantization (zhuohan123, Nov 15, 2023)
f750166  Fix review comments (zhuohan123, Nov 16, 2023)
fd4f4d5  fix naming (zhuohan123, Nov 16, 2023)
a7dd7f4  fix comment (zhuohan123, Nov 16, 2023)
18898f7  rename (zhuohan123, Nov 16, 2023)
2d01ce0  Fix issues in PR #1640 (zhuohan123, Nov 16, 2023)
241bfa8  Fix config (zhuohan123, Nov 16, 2023)
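
For readers skimming the commit list: the first few commits ("Create linear method", "Support llama with the new quantization scheme", "make awq work") describe the core of the refactor, in which linear layers delegate weight creation and the forward matmul to a pluggable "linear method", so AWQ, SqueezeLLM, and unquantized paths can share one layer implementation across all models. The sketch below is a minimal illustration of that pattern under stated assumptions; the names LinearMethod, create_weights, and apply are chosen for clarity here and are not vLLM's exact interface from this PR.

# Illustrative sketch only; real kernels and the real vLLM classes differ.
from abc import ABC, abstractmethod
from typing import Dict

import torch
import torch.nn.functional as F


class LinearMethod(ABC):
    """Creates and applies the weights of one linear layer."""

    @abstractmethod
    def create_weights(self, input_size: int, output_size: int,
                       dtype: torch.dtype) -> Dict[str, torch.Tensor]:
        ...

    @abstractmethod
    def apply(self, weights: Dict[str, torch.Tensor],
              x: torch.Tensor) -> torch.Tensor:
        ...


class UnquantizedLinearMethod(LinearMethod):
    """Plain dense matmul (the default path)."""

    def create_weights(self, input_size, output_size, dtype):
        return {"weight": torch.randn(output_size, input_size, dtype=dtype)}

    def apply(self, weights, x):
        return F.linear(x, weights["weight"])


class ToyQuantizedLinearMethod(LinearMethod):
    """Stand-in for a quantized path (AWQ/SqueezeLLM): int8 weights plus
    per-row scales, dequantized on the fly; real kernels fuse this step."""

    def create_weights(self, input_size, output_size, dtype):
        return {
            "qweight": torch.randint(-128, 127, (output_size, input_size),
                                     dtype=torch.int8),
            "scales": torch.rand(output_size, 1, dtype=dtype),
        }

    def apply(self, weights, x):
        w = weights["qweight"].to(x.dtype) * weights["scales"]
        return F.linear(x, w)


# A tensor-parallel linear layer only talks to a LinearMethod, so the same
# layer (and model) code works for every quantization scheme.
method = ToyQuantizedLinearMethod()
weights = method.create_weights(16, 32, torch.float32)
y = method.apply(weights, torch.randn(4, 16))
print(y.shape)  # torch.Size([4, 32])
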
vllm/config.py: 30 additions, 19 deletions
@@ -140,8 +140,8 @@ def get_head_size(self) -> int:
         # FIXME(woosuk): This may not be true for all models.
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads

-    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
-        """Returns the number of KV heads per GPU worker."""
+    def get_total_num_kv_heads(self) -> int:
+        """Returns the total number of KV heads."""
         # For GPTBigCode & Falcon:
         # NOTE: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
@@ -155,23 +155,34 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
             # Multi-query attention, only one KV head.
             # Currently, tensor parallelism is not supported in this case.
             return 1
-        # For Falcon:
-        if getattr(self.hf_config, "n_head_kv", None) is not None:
-            return (self.hf_config.n_head_kv //
-                    parallel_config.tensor_parallel_size)
-        if getattr(self.hf_config, "num_kv_heads", None) is not None:
-            return (self.hf_config.num_kv_heads //
-                    parallel_config.tensor_parallel_size)
-        # For LLaMA-2:
-        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
-            return (self.hf_config.num_key_value_heads //
-                    parallel_config.tensor_parallel_size)
-        # For ChatGLM-2:
-        if getattr(self.hf_config, "multi_query_group_num", None) is not None:
-            return (self.hf_config.multi_query_group_num //
-                    parallel_config.tensor_parallel_size)
-        total_num_attention_heads = self.hf_config.num_attention_heads
-        return total_num_attention_heads // parallel_config.tensor_parallel_size
+
+        attributes = [
+            # For Falcon:
+            "n_head_kv",
+            "num_kv_heads",
+            # For LLaMA-2:
+            "num_key_value_heads",
+            # For ChatGLM:
+            "multi_query_group_num",
+        ]
+        for attr in attributes:
+            num_kv_heads = getattr(self.hf_config, attr, None)
+            if num_kv_heads is not None:
+                return num_kv_heads
+
+        # For non-grouped-query attention models, the number of KV heads is
+        # equal to the number of attention heads.
+        return self.hf_config.num_attention_heads
+
+    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
+        """Returns the number of KV heads per GPU."""
+        total_num_kv_heads = self.get_total_num_kv_heads()
+        # If tensor parallelism is used, we divide the number of KV heads by
+        # the tensor parallel size. We will replicate the KV heads in the
+        # case where the number of KV heads is smaller than the tensor
+        # parallel size so each GPU has at least one KV head.
+        return max(1,
+                   total_num_kv_heads // parallel_config.tensor_parallel_size)

     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
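
A standalone illustration of the per-GPU behavior introduced above (the config object here is a hypothetical stand-in, not a vLLM class): with 8 KV heads, tensor parallel sizes of 4, 8, and 16 yield 2, 1, and 1 KV heads per GPU, the last case relying on replication rather than rounding down to 0.

# Hypothetical stand-ins mirroring get_total_num_kv_heads / get_num_kv_heads.
from types import SimpleNamespace

hf_config = SimpleNamespace(num_attention_heads=64, num_key_value_heads=8)

def total_num_kv_heads(cfg):
    for attr in ("n_head_kv", "num_kv_heads", "num_key_value_heads",
                 "multi_query_group_num"):
        value = getattr(cfg, attr, None)
        if value is not None:
            return value
    return cfg.num_attention_heads

def num_kv_heads_per_gpu(cfg, tp_size):
    # Replicate KV heads when there are fewer heads than GPUs, so every
    # GPU ends up with at least one.
    return max(1, total_num_kv_heads(cfg) // tp_size)

print(num_kv_heads_per_gpu(hf_config, 4))   # 2
print(num_kv_heads_per_gpu(hf_config, 8))   # 1
print(num_kv_heads_per_gpu(hf_config, 16))  # 1 (replicated, not 0)
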
vllm/engine/async_llm_engine.py: 2 additions, 2 deletions
@@ -142,10 +142,10 @@ def abort_request(self, request_id: str, *, verbose: bool = False) -> None:

         self._request_streams[request_id].finish()

-    def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
+    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
         """Get the new requests and finished requests to be
         sent to the engine."""
-        new_requests: List[dict] = []
+        new_requests: List[Dict] = []
         finished_requests: Set[str] = set()

         while not self._finished_requests.empty():
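
The change above only switches the annotation from the builtin dict to typing.Dict, presumably for consistency with the module's other typing-based annotations. A minimal self-contained sketch of the resulting signature follows; the request payload contents are placeholders, not the engine's real schema.

# Hypothetical, runnable version of the annotated signature above.
from typing import Dict, List, Set, Tuple


def get_new_and_finished_requests() -> Tuple[List[Dict], Set[str]]:
    new_requests: List[Dict] = [{"request_id": "req-0", "prompt": "Hello"}]
    finished_requests: Set[str] = {"req-1"}
    return new_requests, finished_requests


print(get_new_and_finished_requests())
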