
[Model Loading] Speedup model loading with distributed loading #3729

Closed · wants to merge 3 commits

Conversation

chestnut-Q (Contributor)

Hello! The current model loading path is fixed regardless of the tensor parallel size: each rank in a TP group reads the full weight file and then discards the weight tensors it does not need. When --tensor-parallel-size is greater than 1, each rank keeps only 1/tp_size of most parameters, so these redundant reads amount to significant extra weight IO.

Observing that disk IO is slow (particularly for .bin files) while transfers between GPUs are fast, we can adopt a distributed loading approach: each worker loads only 1/tp_size of the weight files (by file division, or for SafeTensors checkpoints, by tensor division), and the parameters needed by other workers are then exchanged with torch.distributed.scatter or torch.distributed.broadcast. This reduces disk IO to 1/tp_size.
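A minimal sketch of the idea (not the exact code in this PR), assuming torch.distributed is already initialized with tp_size ranks and that `shard_files` is a hypothetical per-rank partition of the checkpoint's safetensors files:

```python
import torch.distributed as dist
from safetensors.torch import load_file

def load_weights_distributed(shard_files, tp_rank, tp_size):
    # 1. Each rank reads only its assigned files: 1/tp_size of the disk IO.
    local_tensors = {}
    for path in shard_files[tp_rank]:
        local_tensors.update(load_file(path))

    # 2. Exchange the tensors over the fast GPU interconnect instead of
    #    having every rank re-read them from disk. broadcast_object_list is
    #    used here for brevity; the PR describes scatter/broadcast on the raw
    #    tensors so each rank only receives the slices it actually keeps.
    full_state = {}
    for owner in range(tp_size):
        payload = [local_tensors if owner == tp_rank else None]
        dist.broadcast_object_list(payload, src=owner)
        full_state.update(payload[0])

    # 3. Each rank then slices out its 1/tp_size shard of every
    #    tensor-parallel parameter, as the existing weight loaders already do.
    return full_state
```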

I have implemented example distributed loading code in llama.py and baichuan.py. I believe other models (if needed) can easily implement similar logic. To keep compatibility with the existing code, the args introduced in this PR are optional, so if you do not wish to use distributed loading, the original code path requires no modifications.

When --tensor-parallel-size >= 4, distributed loading significantly accelerates loading, typically by 40% or more. Here are the results on my machine (8×A100) for Llama-2-70b and Baichuan2-13B:

|  | Llama-2-70b (TP8) | Baichuan2-13B (TP4) |
| --- | --- | --- |
| Vanilla | 249.5s | 45.3s |
| Distributed | 141.5s | 25.2s |
| Speedup | 43.3% | 44.4% |

@ywang96 (Member) left a comment:

Hello! Thanks for the PR; I left some comments & questions. This is indeed a neat feature, but please consider using utils & helper modules.

IMO we would ideally like to reduce the need to change each model file as much as possible, to keep the codebase easier to maintain.

Review comments (outdated, resolved) on:
- vllm/model_executor/layers/activation.py
- vllm/model_executor/models/baichuan.py
- vllm/worker/model_runner.py
- vllm/worker/model_runner.py
- vllm/model_executor/models/baichuan.py
- vllm/model_executor/models/baichuan.py
@chestnut-Q (Contributor, Author)

@ywang96 Thank you for your suggestion! I have moved the duplicated code into utils as you advised, which results in minimal changes to each model file and makes them clearer. Additionally, I have added the CLI arg --use-distributed-loading and expanded support to other models whose parameter structure is similar to llama. If there are any more questions or suggestions, feel free to comment :)
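For reference, a hypothetical usage sketch from the Python API. The flag name comes from this PR; exposing it as a keyword argument on `LLM` is my assumption for illustration, not necessarily how the PR wires it through:

```python
from vllm import LLM

# Hypothetical: enable the PR's optional distributed loading path. The
# keyword below mirrors the --use-distributed-loading CLI arg; whether it is
# surfaced on LLM exactly like this is an assumption.
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    tensor_parallel_size=8,
    # use_distributed_loading=True,
)
```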

@BenNR commented May 18, 2024

Hi guys, this is a truly valuable feature. Is it still moving forward into an official vLLM release? Not trying to be pushy; great work and a cool concept!

@sdake commented Sep 2, 2024

cc/ @sdake

@youkaichao (Member)

Closing as this has become stale.

Please also see #6127 (comment).

We recommend using the safetensors format, and then this optimization is not needed.
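A minimal sketch of converting a .bin checkpoint to the recommended safetensors format (file names are illustrative; tied/shared tensors would also need de-duplication, which this sketch does not handle):

```python
import torch
from safetensors.torch import save_file

# Load the pickle-based checkpoint on CPU.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")

# safetensors requires contiguous tensors; tied/shared weights would need to
# be de-duplicated before saving, which is skipped here.
state_dict = {name: t.contiguous() for name, t in state_dict.items()}
save_file(state_dict, "model.safetensors")
```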

@youkaichao closed this Sep 3, 2024