Accelerate LLaMA model loading #234
Conversation
Oh sorry, didn't mean to do that. :P
@AlpinDale Can we merge this? Currently, model loading is extremely slow.
@JF-D Could you please add some comments to your changes? A tad hard to read them at the moment 😬
Resolve conflicts for reference. |
@JF-D Sorry for the long delay. A lot of people have actually asked for safetensor support, and your PR looks great for LLaMA models! Do you think there is a possibility to extend what you did to all models by just modifying the `hf_model_weights_iterator` function?
@zhuohan123 I think it's possible, and I've updated the `hf_model_weights_iterator` function accordingly.
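For readers following along, here is a minimal sketch of what a safetensors-aware weight iterator in the spirit of `hf_model_weights_iterator` could look like. The function name and the file-discovery logic are illustrative assumptions, not the exact vLLM implementation:

```python
import glob
import os
from typing import Any, Iterator, Tuple

from safetensors import safe_open


def safetensors_weights_iterator(model_dir: str) -> Iterator[Tuple[str, Any]]:
    """Lazily yield (name, weight) pairs from *.safetensors shards in model_dir.

    Each yielded weight is a lazy slice object; the caller materializes it
    (fully with `[:]`, or partially with a sub-range) while the file is still
    open, i.e. inside the consuming loop of the model's load_weights.
    """
    shard_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
    for shard_file in shard_files:
        with safe_open(shard_file, framework="pt", device="cpu") as f:
            for name in f.keys():
                # get_slice does not read any tensor data yet; indexing does.
                yield name, f.get_slice(name)
```

The key difference from `torch.load` on pickled .bin shards is that nothing is read from disk until a slice is indexed, so each tensor-parallel rank can read only its own shard of every weight.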
Again, thank you for your great contribution! We tested the code last week and it works great. I left some questions and small suggestions; once they are addressed, this PR should be ready to merge. BTW, you can run `format.sh --all` to format your changes.
vllm/model_executor/models/llama.py (Outdated)
if "embed_tokens" in name or "lm_head" in name: | ||
param = state_dict[name] | ||
# Consider padding in the vocab size. | ||
padded_vocab_size = param.shape[0] * tp_size | ||
if padded_vocab_size > self.config.vocab_size: | ||
load_padded_tensor_parallel_vocab(param, loaded_weight, name, | ||
self._column_parallel_weights, | ||
self._row_parallel_weights, | ||
tensor_model_parallel_rank) | ||
continue |
Is this part a must-have change for safetensors, or is it another optimization? If the latter, maybe we can move it into another PR and keep this PR purely about loading safetensors?
If we can assume that the vocab will not be padded, this is not a must-have change. Can we make that assumption here?
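To make the condition concrete, here is a tiny worked example with hypothetical shapes (a padded per-rank shard of 4032 rows versus an unpadded 32000-entry vocab):

```python
# Hypothetical shapes: each tensor-parallel rank holds a 4032-row embedding
# shard, while the checkpoint's embedding has 32000 rows in total.
tp_size = 8
param_rows = 4032                          # param.shape[0] on one rank
vocab_size = 32000                         # self.config.vocab_size

padded_vocab_size = param_rows * tp_size   # 32256
takes_padded_path = padded_vocab_size > vocab_size
print(padded_vocab_size, takes_padded_path)  # 32256 True
```

If the vocab is never padded, `param.shape[0] * tp_size` equals `vocab_size` exactly and this branch never runs, which is what the question above is probing.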
vllm/model_executor/weight_utils.py (Outdated)
```python
def load_padded_tensor_parallel_vocab(
    param: torch.Tensor,
    loaded_weight: torch.Tensor or object,
    param_name: str,
    column_parallel_weight_names: List[str],
    row_parallel_weight_names: List[str],
    tensor_model_parallel_rank: int,
) -> None:
    for p in column_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[0]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[start_idx:end_idx]
            break

    # Convert PySafeSlice object to torch.Tensor.
    if not isinstance(loaded_weight, torch.Tensor):
        loaded_weight = loaded_weight[:]

    param[:loaded_weight.shape[0]].copy_(loaded_weight)
```
ditto, we can exclude this function from this PR if it's not related to safetensors.
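For context on the PySafeSlice conversion in the snippet above: the slice objects come from safetensors' lazy-loading API, and indexing is what actually reads bytes from disk. A minimal sketch (file and tensor names, as well as the shard shape, are hypothetical):

```python
import torch
from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="pt", device="cpu") as f:
    lazy = f.get_slice("model.embed_tokens.weight")  # PySafeSlice: no data read yet
    tp_rank, shard_rows = 1, 4000                    # hypothetical tensor-parallel shard
    shard = lazy[tp_rank * shard_rows:(tp_rank + 1) * shard_rows]  # reads only this shard
    full = lazy[:]                                   # `[:]` materializes the whole tensor
    assert isinstance(shard, torch.Tensor) and isinstance(full, torch.Tensor)
```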
Thank you for your contribution again! I will merge this PR first, and then add safetensor loading for other models in another PR.
This PR accelerates LLaMA model weight loading with safetensors. I find that the current weight-loading implementation roughly doubles its time cost as the tensor-model-parallel size increases (see the LLaMA-65B loading-time table below).
I think it is ready for review.
Code adapted from https://github.com/huggingface/text-generation-inference/blob/v0.8.2/server/text_generation_server/models/flash_llama.py#L206
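The speedup comes from each tensor-parallel rank reading only its own slice of every weight instead of deserializing the full checkpoint. A hedged sketch of the column-parallel case, loosely modeled on the linked TGI code (the function name and sharding convention are illustrative):

```python
import torch
from safetensors import safe_open


def load_column_parallel_shard(path: str, name: str,
                               tp_rank: int, tp_size: int) -> torch.Tensor:
    """Read only rows [tp_rank*shard : (tp_rank+1)*shard] of `name` from a safetensors file."""
    with safe_open(path, framework="pt", device="cpu") as f:
        lazy = f.get_slice(name)
        total_rows = lazy.get_shape()[0]
        shard = total_rows // tp_size
        # Only this row range is read from disk and copied into host memory.
        return lazy[tp_rank * shard:(tp_rank + 1) * shard]
```

With the previous pickled .bin loading, every rank had to materialize each full tensor before slicing it, so total bytes read grow with the tensor-parallel size; with lazy slices they stay roughly constant, which matches the loading-time growth described above.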