
Add multi-lora support for Triton vLLM backend #23

Merged: 28 commits into triton-inference-server:main on Apr 18, 2024

Conversation

@l1cacheDell (Contributor) commented Nov 29, 2023:

Hello, vLLM recently added the multi-LoRA feature; see this PR for more information.

To support LoRA requests, I implemented LoRA request handling in the Triton backend.


The vLLM scheme

As of 2023-11-29, vLLM only supports loading LoRA weights from local paths onto the model backbone (no auto-download from S3 or the Hugging Face Hub); they state that remote weight downloading will be implemented later.

In addition, the LoRA path must be configured by the sender; see their interface for details.


The Triton scheme

Under these constraints, Triton should act as a transit point that processes and converts the user's request.

From the user's side, all that is needed is the name of the target LoRA to query:

python3 client_lora.py -l alpaca

Triton then checks whether the target LoRA exists among the local LoRA weights. If it does, Triton wraps the target LoRA information into a new LoRARequest object and sends it to the AsyncLLMEngine for execution.

See model.py for more information.
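As a rough illustration of this lookup-and-wrap step (a sketch, not the exact code in model.py; the function name is made up, and lora_repository stands for the parsed multi_lora.json with integer ids derived from key order, which is an assumption of this sketch):

from vllm.lora.request import LoRARequest

def build_lora_request(lora_name, lora_repository):
    """Map a requested adapter name to a vLLM LoRARequest, or None for a plain request."""
    if not lora_name:
        return None  # no adapter requested; run the plain backbone
    if lora_name not in lora_repository:
        raise ValueError(f"LoRA '{lora_name}' is not listed in multi_lora.json")
    # vLLM identifies each adapter by a positive integer id; here it is derived from key order.
    lora_int_id = list(lora_repository).index(lora_name) + 1
    return LoRARequest(lora_name, lora_int_id, lora_repository[lora_name])

The resulting object can then be passed to the engine through the lora_request argument of AsyncLLMEngine.generate.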


Changes to model repository

model.json needs new arguments such as "enable_lora": "true" to enable multi-LoRA in vLLM; see EngineArgs for the full list.
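For illustration, a model.json with LoRA enabled might look like the following (the backbone path is a placeholder, and the fields other than enable_lora are just the usual optional EngineArgs entries):

{
    "model": "/vllm_workspace/weights/backbone/llama-7b-hf",
    "disable_log_requests": "true",
    "gpu_memory_utilization": 0.8,
    "enable_lora": "true",
    "max_lora_rank": 16
}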

The multi_lora.json file maps a specific name to each LoRA adapter. Since model.json holds the vLLM engine initialization arguments, the LoRA weight paths cannot be stored there, so I created a separate file to manage all local LoRA weights.

It can be configured like this:

{
    "alpaca": "/vllm_workspace/weights/loras/alpaca-lora-7b",
    "bactrian": "/vllm_workspace/weights/loras/bactrian-x-llama-7b-lora"
}

The Triton backend reads this file during initialization and uses it as the repository of local LoRA weights.
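The loading step itself is simple; a minimal sketch, assuming multi_lora.json sits next to model.json in the model directory (names here are illustrative, the real logic lives in model.py):

import json
import os

def load_lora_repository(model_dir):
    """Return {lora_name: local_weight_path} parsed from multi_lora.json, or {} if absent."""
    path = os.path.join(model_dir, "multi_lora.json")
    if not os.path.exists(path):
        return {}  # LoRA not configured; the backend behaves exactly as before
    with open(path) as f:
        return json.load(f)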


Test

The code passes the tests both with and without LoRA support; leaving enable_lora unset does not affect the original behavior.

python3 client_lora.py

Hello, my name is
I am a 20 year old student from the Netherlands. I am currently

python3 client_lora.py -l alpaca

Hello, my name is
I am a student at
I am currently studying
I am interested in learning

For more information about deployment, see docs/llama_multi_lora_tutorial.md.

Review thread on src/model.py (outdated):

import numpy as np
import triton_python_backend_utils as pb_utils
from vllm import SamplingParams
from vllm.sampling_params import SamplingParams # MODIFY: bug fix, it won't bother
Collaborator: could you please clarify this bug?

@l1cacheDell (Contributor, Author) commented Dec 28, 2023:

I'm using pip install -e . to install vLLM from their source code (to test the multi-LoRA version); in this case, from vllm import SamplingParams throws an ImportError:

ImportError: cannot import name 'SamplingParams' from 'vllm' (unknown location)

So I import SamplingParams directly from the source module instead, to avoid this problem.

I think this modification affects neither regular users nor developers, and it resolves the ImportError.

Collaborator: Sounds good. I'll keep this conversation unresolved for history.

Another review thread on src/model.py was marked outdated and resolved.
@oandreeva-nv (Collaborator) left a comment:

Amazing job on this PR! I've noticed that you have a lot of MODIFY comments in model.py; is this PR still in progress? Feel free to tag me when it is done, and I'll review.

@l1cacheDell (Contributor, Author) commented Dec 28, 2023:

@oandreeva-nv Thank you very much for your thorough review of my code and for providing me with detailed guidance and suggestions.

I added the MODIFY markers during development to indicate which lines had changed; I have removed them in the latest commit. The code work is done.

Further changes have been made to the code and docs based on your suggestions.

Note that the vLLM multi-LoRA feature may not ship with container 23.12 (which corresponds to vLLM 0.2.3). I will keep working on multi-LoRA support in the backend, fixing bugs and adding new features.

@l1cacheDell (Contributor, Author) commented Dec 29, 2023:

2023-12-29 update:

  • Yard1's multi-LoRA PR has not been merged into vLLM's main branch yet, so I edited llama_multi_lora_tutorial.md to describe my way of installing vLLM with the multi-LoRA feature, giving other developers an easier and clearer path to get started.
  • In the latest commit, I fixed a non-graceful termination bug: the original code was missing a step in request error handling. When terminating the Triton server after a LoRA request error response, it would trigger a segmentation fault (a sketch of the error-handling pattern follows the logs below):
E1229 07:55:42.707509 6674 main.cc:517] failed to stop server: Internal - Exit timeout expired. Exiting immediately.
root@lee-MS-7D25:/vllm_workspace# I1229 07:55:44.442169 6737 pb_stub.cc:1815]  Non-graceful termination detected. 
*** SIGSEGV received at time=1703836544 on cpu 4 ***
PC: @     0x7f7311871a84  (unknown)  (unknown)
    @     0x7f7311817520  (unknown)  (unknown)
[2023-12-29 07:55:44,479 E 6737 6737] logging.cc:361: *** SIGSEGV received at time=1703836544 on cpu 4 ***
[2023-12-29 07:55:44,479 E 6737 6737] logging.cc:361: PC: @     0x7f7311871a84  (unknown)  (unknown)
[2023-12-29 07:55:44,480 E 6737 6737] logging.cc:361:     @     0x7f7311817520  (unknown)  (unknown)
Fatal Python error: Segmentation fault

Stack (most recent call first):
  <no Python frame>

After the fix, the server shuts down smoothly:

I1229 08:03:17.533748 7999 server.cc:331] Timeout 29: Found 1 live models and 0 in-flight non-inference requests
I1229 08:03:17.536050 7999 model.py:362] [vllm] Issuing finalize to vllm backend
I1229 08:03:18.533890 7999 server.cc:331] Timeout 28: Found 1 live models and 0 in-flight non-inference requests
I1229 08:03:19.534038 7999 server.cc:331] Timeout 27: Found 1 live models and 0 in-flight non-inference requests
I1229 08:03:19.536044 7999 model.py:190] [vllm] Shutdown complete
(RayWorkerVllm pid=9004) INFO 12-29 08:02:44 model_runner.py:567] Graph capturing finished in 35 secs.
(RayWorkerVllm pid=9004) [W CUDAGraph.cpp:145] Warning: Waiting for pending NCCL work to finish before starting graph capture. (function operator())
I1229 08:03:20.534161 7999 server.cc:331] Timeout 26: Found 1 live models and 0 in-flight non-inference requests
I1229 08:03:21.534340 7999 server.cc:331] Timeout 25: Found 1 live models and 0 in-flight non-inference requests
I1229 08:03:22.534539 7999 server.cc:331] Timeout 24: Found 1 live models and 0 in-flight non-inference requests
I1229 08:03:22.783541 7999 model_lifecycle.cc:603] successfully unloaded 'vllm_model' version 1
I1229 08:03:23.534724 7999 server.cc:331] Timeout 23: Found 0 live models and 0 in-flight non-inference requests
root@lee-MS-7D25:/vllm_workspace#
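The fix essentially amounts to returning a proper error response and closing the response stream instead of letting the exception escape. A sketch of that pattern using the Triton Python backend utilities (not the exact diff; the helper name is made up):

import triton_python_backend_utils as pb_utils

def send_lora_error(response_sender, message):
    # Wrap the failure in a TritonError so the client receives it as an error response.
    error = pb_utils.TritonError(message)
    response = pb_utils.InferenceResponse(output_tensors=[], error=error)
    # Mark the decoupled response stream as complete so the request is released
    # and the server can later shut down gracefully.
    response_sender.send(response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)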

Lastly, I hope this PR can get merged. :) @oandreeva-nv

@germanjke commented:

Which model.py should we update for the multi-LoRA feature?
This one, https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/src/model.py, looks like it has no multi-LoRA support.

@germanjke commented:

OK, right now everything is fine if I clone the latest vLLM from the main branch rather than the branch you posted in the docs.

@l1cacheDell (Contributor, Author) commented:
Everything in this PR is current with the latest vLLM version. I've revised the documentation to include instructions on installing vLLM with punica kernel compilation.

Review threads on src/model.py and samples/client_lora.py were resolved, and a flagged issue in src/model.py was fixed.
@oandreeva-nv (Collaborator) commented:

@SamuraiBUPT, may I also ask you to fix some pre-commit issues?

The easiest way is to install pre-commit (https://pre-commit.com/) and run pre-commit run --all-files.

Some spelling issues should be fixed manually.

@l1cacheDell (Contributor, Author) commented:
@oandreeva-nv Modifications have been made; I have run the pre-commit checks locally and fixed some issues as well.

One open issue remains about sending a LoRA error response.

@oandreeva-nv (Collaborator) commented:
Created pipeline (id: 14152157).

@oandreeva-nv (Collaborator) commented:

Hi @SamuraiBUPT, may I ask you to rebase your PR on top of the main branch? We recently updated the vLLM version to 0.4.0, and it required some test changes.

@l1cacheDell (Contributor, Author) commented:
@oandreeva-nv done :)

@oandreeva-nv (Collaborator) commented:

Re-running CI.

@l1cacheDell (Contributor, Author) commented:
I apologize for my ignorance; I did not fully adhere to the contribution process outlined in the Triton Contribution Guidelines initially.

However, I have now submitted a CLA to [email protected].

Several review threads on docs/llama_multi_lora_tutorial.md, ci/L0_backend_vllm/test.sh, and samples/client_lora.py were resolved.
@l1cacheDell (Contributor, Author) commented:
@oandreeva-nv @tanmayv25 Hello, I have addressed the following issues in the new commit:

  1. Updated copyright comments in the documentation and the copyright year in test scripts.
  2. Revised the tutorial document to clarify NGC version, vllm version, and Step 6: send a request, along with some wording enhancements.
  3. Modified the CI test scripts to use gemma-2b and its two related LoRA adapters, aligning the CI test code with current requirements.
  4. Removed client_lora.py to eliminate redundant code.

Regarding the other issues:

  • Default installation of VLLM_INSTALL_PUNICA_KERNELS in container images: I recommend enabling it by default (an example build command is shown after this list). I uninstalled vllm v0.3.0 and installed vllm v0.4.0.post1 from source, which took approximately 10 minutes 31 seconds on a 12th Gen Intel(R) Core(TM) i9-12900K; the installation time may vary for other developers.
  • Changes to how local LoRA weights are applied: none in v0.4.0.post1; it still does not support fetching adapter weights dynamically from Hugging Face. If the weight-loading process changes, please let me know.
  • Transitioning the CI tests from llama to gemma: given the space constraints of our CI test environment, I reviewed all supported models and LoRAs listed in the vLLM documentation (Supported Models). Gemma-2b is the smallest LoRA-compatible model I could find, so I opted for it in the hope that it fits our CI requirements.
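For reference, building vLLM from source with the Punica (multi-LoRA) kernels is typically done by setting the build flag before installation (an illustrative command, not taken from this PR):

VLLM_INSTALL_PUNICA_KERNELS=1 pip install -e .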

By the way, I don't know where our pipeline is. I greatly appreciate you providing the pipeline ID, but could you also provide the location of the pipeline, or the platform it runs on?

@oandreeva-nv (Collaborator) commented Apr 15, 2024:

Hi @SamuraiBUPT, thanks for the thorough review! I'll re-run the pipeline and let you know the results.

At the moment the CI/CD process is internal and external contributors do not have access to it, so maintainers of the Triton codebase are responsible for making sure CI/CD passes for all external PRs.

I've started a new CI pipeline: 14260916 14282742.

@tanmayv25 (Contributor) commented Apr 15, 2024:

@SamuraiBUPT Thanks for resolving my questions and for the excellent contribution! The code changes look good to me. I will defer to @oandreeva-nv for final approval.

We will enable the PUNICA kernels in our container image as a follow-up item.

@l1cacheDell (Contributor, Author) commented:

Hi @oandreeva-nv, could you please let me know how it's going? If there are any issues with the CI test results that require bug fixes, I can help resolve them. I appreciate your time. :)

@oandreeva-nv (Collaborator) commented:

Hi @SamuraiBUPT, apologies for the delay. Our initial tests were running on V100s, but it turned out that V100s are too old, so I made some adjustments to the test infrastructure to run on an A100. Results should be ready soon. Again, apologies for the delay.

@oandreeva-nv (Collaborator) commented:

Tests passed on A100: jobs/89630242.
I'll merge this PR today. Thank you @SamuraiBUPT for the amazing work and this valuable contribution!

@oandreeva-nv (Collaborator) left a review:

Amazing work!

@oandreeva-nv merged commit f064eed into triton-inference-server:main on Apr 18, 2024. 3 checks passed.

6 participants