Accelerate LLaMA model loading #234

Merged: 7 commits merged into vllm-project:main on Aug 30, 2023

Conversation

@JF-D (Contributor) commented on Jun 25, 2023:

This PR accelerates LLaMA model weight loading with safetensors. I found that the current weight-loading implementation doubles the time cost as tensor-model parallelism increases (see the loading-time table below for LLaMA-65B).

| Parallelism Degree | Original (minutes) | Safetensors (minutes) |
|---|---|---|
| 1 | ~5 | ~5 |
| 2 | ~10 | ~5 |
| 4 | ~10 | ~5 |

I think it is ready for review.
Code adapted from https://github.com/huggingface/text-generation-inference/blob/v0.8.2/server/text_generation_server/models/flash_llama.py#L206
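
For illustration, a minimal sketch of the kind of shard-aware loading safetensors enables (the helper name and file path are hypothetical, not the exact code in this PR): each tensor-parallel rank opens the checkpoint lazily and reads only the rows of its own shard, instead of deserializing the full pickle file on every rank.

import torch
from safetensors import safe_open


def load_column_parallel_shard(path: str, name: str,
                               tp_rank: int, tp_size: int) -> torch.Tensor:
    """Read only this rank's rows of a column-parallel weight from disk."""
    with safe_open(path, framework="pt") as f:
        weight_slice = f.get_slice(name)          # lazy handle, nothing read yet
        total_rows = weight_slice.get_shape()[0]
        shard_size = total_rows // tp_size
        start = tp_rank * shard_size
        # Only the requested rows are read from the memory-mapped file.
        return weight_slice[start:start + shard_size]

Because the file is memory-mapped, slicing reads just the bytes for this rank's shard, which is consistent with the flat Safetensors column in the table above.
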

@zhuohan123 self-requested a review on June 25, 2023 at 15:01.
@AlpinDale (Contributor) commented on Jun 28, 2023:

Oh sorry, didn't mean to do that. :P

@lucasjinreal commented:

@AlpinDale Can we merge this? Model loading is currently extremely slow.

@creatorrr commented:

@JF-D could you please add some comments to your changes? A tad hard to read them at the moment 😬

@JF-D (Contributor, Author) commented on Jul 19, 2023:

Resolved the merge conflicts for reference.

@zhuohan123 (Member) left a comment:

@JF-D Sorry for the long delay. A lot of people have actually asked for safetensor support and your PR looks great for LLaMA models! Do you think there is a possibility to extend what you did to all models by just modifying the hf_model_weights_iterator function?

@JF-D (Contributor, Author) commented on Aug 5, 2023:

@zhuohan123 I think it's possible, and I've updated the hf_model_weights_iterator function. Maybe you can review it.
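
A rough sketch of the idea behind a format-agnostic hf_model_weights_iterator (names and structure are simplified and hypothetical, not the exact code in this PR): prefer *.safetensors files when they exist in the checkpoint directory and fall back to the original *.bin shards otherwise, so every model can benefit without per-model changes.

import glob
import os

import torch
from safetensors import safe_open


def iterate_hf_weights(model_dir: str, use_safetensors: bool = True):
    """Yield (name, weight) pairs from a locally downloaded HF checkpoint."""
    st_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
    if use_safetensors and st_files:
        for st_file in st_files:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():
                    # Yield a lazy slice so the caller can read only its TP shard.
                    yield name, f.get_slice(name)
    else:
        for bin_file in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
            state_dict = torch.load(bin_file, map_location="cpu")
            for name, weight in state_dict.items():
                yield name, weight

A model's load_weights can then treat the yielded object uniformly, materializing only the part it needs (loaded_weight[start:end], or loaded_weight[:] for the full tensor).
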

@zhuohan123 (Member) left a comment:

Again, thank you for your great contribution! We tested the code last week and it works great! I left some questions and small suggestions. Once those are fixed, this PR should be ready to merge. BTW, you can run format.sh --all to format your changes.

Comment on lines 329 to 338
if "embed_tokens" in name or "lm_head" in name:
param = state_dict[name]
# Consider padding in the vocab size.
padded_vocab_size = param.shape[0] * tp_size
if padded_vocab_size > self.config.vocab_size:
load_padded_tensor_parallel_vocab(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
continue
Member:

Is this part a must-have change for safetensors, or is it another optimization? If the latter, maybe we can include this part in another PR and keep this PR merely about loading safetensors?

Contributor Author:

If we can assume that the vocab will not be padded, this is not a must-have change. Can we make that assumption here?
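
For context, a small hypothetical example of when padded_vocab_size exceeds vocab_size and the special-cased path above is needed:

# Hypothetical numbers: the vocab is padded so every TP rank gets an equal shard.
vocab_size = 32000
tp_size = 3
shard_rows = -(-vocab_size // tp_size)      # ceil division -> 10667 rows per rank
padded_vocab_size = shard_rows * tp_size    # 32001 > vocab_size, so padding exists
# The last rank's shard then ends with a padding row that the loader must not
# try to fill from the checkpoint.
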

vllm/model_executor/weight_utils.py (outdated review thread, resolved)
Comment on lines 177 to 197
def load_padded_tensor_parallel_vocab(
    param: torch.Tensor,
    loaded_weight: torch.Tensor or object,
    param_name: str,
    column_parallel_weight_names: List[str],
    row_parallel_weight_names: List[str],
    tensor_model_parallel_rank: int,
) -> None:
    for p in column_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[0]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[start_idx:end_idx]
            break

    # convert PySafeSlice object to torch.Tensor
    if not isinstance(loaded_weight, torch.Tensor):
        loaded_weight = loaded_weight[:]

    param[:loaded_weight.shape[0]].copy_(loaded_weight)
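
A hypothetical call to the function above, with illustrative shapes not taken from the PR, showing how the padding rows are left untouched:

import torch

tp_rank = 3                                  # last of 4 tensor-parallel ranks
param = torch.zeros(8064, 4096)              # padded shard: 32256 / 4 rows
full_weight = torch.randn(32000, 4096)       # normally a lazy PySafeSlice

load_padded_tensor_parallel_vocab(
    param,
    full_weight,
    "model.embed_tokens.weight",
    column_parallel_weight_names=["embed_tokens.weight"],
    row_parallel_weight_names=[],
    tensor_model_parallel_rank=tp_rank,
)
# Rows 24192:32000 of the checkpoint fill param[:7808]; the shard's last
# 256 rows remain zero as vocab padding.
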
Member:

ditto, we can exclude this function from this PR if it's not related to safetensors.

vllm/model_executor/weight_utils.py (outdated review thread, resolved)
@zhuohan123 (Member) left a comment:

Thank you for your contribution again! I will merge this PR first, and then add safetensor loading for other models in another PR.

@zhuohan123 merged commit 0d93f15 into vllm-project:main on Aug 30, 2023.
2 checks passed
liuyanyi pushed a commit to liuyanyi/vllm that referenced this pull request Sep 12, 2023
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
@JF-D deleted the main branch March 5, 2024 07:18
@JF-D restored the main branch March 5, 2024 07:19
@JF-D deleted the main branch March 5, 2024 07:20
sjchoi1 pushed a commit to casys-kaist-internal/vllm that referenced this pull request May 7, 2024
mht-sharma pushed a commit to mht-sharma/vllm that referenced this pull request Oct 30, 2024