
Commit

TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models (vllm-project#1622)

Refactor the tensor parallelism, quantization, and weight-loading code.

Summary of the new features enabled by this PR:
- **All models** can now be quantized with AWQ and SqueezeLLM, and [soon GPTQ](vllm-project#1580).
- The model loading code is much simpler.
- Tensor parallelism is supported for all MQA/GQA models, even when the number of key/value heads is smaller than the tensor parallel size; in that case the KV heads are replicated so that every GPU holds at least one (see the sketch after this list).
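To make the last bullet concrete, here is a minimal sketch of the per-GPU KV-head rule (a standalone helper written for illustration only; the head counts and tensor parallel sizes below are hypothetical, not taken from this PR):

```python
def kv_heads_per_gpu(total_kv_heads: int, tensor_parallel_size: int) -> int:
    # Each GPU normally gets an equal slice of the KV heads, but never fewer
    # than one: when tensor_parallel_size exceeds the number of KV heads,
    # heads are replicated so every GPU still holds a head.
    return max(1, total_kv_heads // tensor_parallel_size)

# A GQA model with 8 KV heads split across 4 GPUs -> 2 heads per GPU.
assert kv_heads_per_gpu(total_kv_heads=8, tensor_parallel_size=4) == 2
# With 16 GPUs there are fewer heads than GPUs, so heads are replicated
# and each GPU still keeps 1 head.
assert kv_heads_per_gpu(total_kv_heads=8, tensor_parallel_size=16) == 1
```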
zhuohan123 authored Nov 16, 2023
1 parent 3a632f8 commit e3d1abb
Showing 36 changed files with 2,159 additions and 2,508 deletions.
49 changes: 30 additions & 19 deletions vllm/config.py
@@ -140,8 +140,8 @@ def get_head_size(self) -> int:
         # FIXME(woosuk): This may not be true for all models.
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads
 
-    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
-        """Returns the number of KV heads per GPU worker."""
+    def get_total_num_kv_heads(self) -> int:
+        """Returns the total number of KV heads."""
         # For GPTBigCode & Falcon:
         # NOTE: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
@@ -155,23 +155,34 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
             # Multi-query attention, only one KV head.
             # Currently, tensor parallelism is not supported in this case.
             return 1
-        # For Falcon:
-        if getattr(self.hf_config, "n_head_kv", None) is not None:
-            return (self.hf_config.n_head_kv //
-                    parallel_config.tensor_parallel_size)
-        if getattr(self.hf_config, "num_kv_heads", None) is not None:
-            return (self.hf_config.num_kv_heads //
-                    parallel_config.tensor_parallel_size)
-        # For LLaMA-2:
-        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
-            return (self.hf_config.num_key_value_heads //
-                    parallel_config.tensor_parallel_size)
-        # For ChatGLM-2:
-        if getattr(self.hf_config, "multi_query_group_num", None) is not None:
-            return (self.hf_config.multi_query_group_num //
-                    parallel_config.tensor_parallel_size)
-        total_num_attention_heads = self.hf_config.num_attention_heads
-        return total_num_attention_heads // parallel_config.tensor_parallel_size
+
+        attributes = [
+            # For Falcon:
+            "n_head_kv",
+            "num_kv_heads",
+            # For LLaMA-2:
+            "num_key_value_heads",
+            # For ChatGLM:
+            "multi_query_group_num",
+        ]
+        for attr in attributes:
+            num_kv_heads = getattr(self.hf_config, attr, None)
+            if num_kv_heads is not None:
+                return num_kv_heads
+
+        # For non-grouped-query attention models, the number of KV heads is
+        # equal to the number of attention heads.
+        return self.hf_config.num_attention_heads
+
+    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
+        """Returns the number of KV heads per GPU."""
+        total_num_kv_heads = self.get_total_num_kv_heads()
+        # If tensor parallelism is used, we divide the number of KV heads by
+        # the tensor parallel size. We will replicate the KV heads in the
+        # case where the number of KV heads is smaller than the tensor
+        # parallel size so each GPU has at least one KV head.
+        return max(1,
+                   total_num_kv_heads // parallel_config.tensor_parallel_size)
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
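As a quick sanity check of the attribute-probing behavior above, the following standalone sketch mimics `get_total_num_kv_heads` against stubbed Hugging Face configs (the stub objects and head counts are invented for illustration, and the Falcon/multi-query special case is skipped):

```python
from types import SimpleNamespace

def total_kv_heads(hf_config) -> int:
    # Probe the per-model-family attribute names in the same order as above.
    for attr in ("n_head_kv", "num_kv_heads",
                 "num_key_value_heads", "multi_query_group_num"):
        value = getattr(hf_config, attr, None)
        if value is not None:
            return value
    # No GQA attribute: fall back to one KV head per attention head.
    return hf_config.num_attention_heads

falcon_like = SimpleNamespace(num_attention_heads=128, n_head_kv=8)
llama2_like = SimpleNamespace(num_attention_heads=64, num_key_value_heads=8)
mha_like = SimpleNamespace(num_attention_heads=32)  # no GQA attribute

print(total_kv_heads(falcon_like))   # 8
print(total_kv_heads(llama2_like))   # 8
print(total_kv_heads(mha_like))      # 32
```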
4 changes: 2 additions & 2 deletions vllm/engine/async_llm_engine.py
@@ -142,10 +142,10 @@ def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
 
         self._request_streams[request_id].finish()
 
-    def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
+    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
         """Get the new requests and finished requests to be
         sent to the engine."""
-        new_requests: List[dict] = []
+        new_requests: List[Dict] = []
         finished_requests: Set[str] = set()
 
         while not self._finished_requests.empty():
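The `async_llm_engine.py` hunk is purely a typing cleanup; as a small illustration (assuming the usual `typing` imports already present at the top of the file), `Dict` is the `typing` alias of the built-in `dict` and keeps the annotation consistent with the other generic aliases in the signature:

```python
from typing import Dict, List, Set, Tuple

def new_and_finished_sketch() -> Tuple[List[Dict], Set[str]]:
    # Unparameterized Dict behaves like dict here; both accept any dict value.
    new_requests: List[Dict] = [{"request_id": "req-0"}]
    finished_requests: Set[str] = {"req-1"}
    return new_requests, finished_requests
```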