Add multi-lora support for Triton vLLM backend #23
Conversation
src/model.py (outdated diff):
import numpy as np
import triton_python_backend_utils as pb_utils
- from vllm import SamplingParams
+ from vllm.sampling_params import SamplingParams  # MODIFY: bug fix, it won't bother
Could you please clarify this bug?
I'm using pip install -e . to install vLLM from its source code (to test the multi-LoRA version). In this case, using from vllm import SamplingParams throws an ImportError:
ImportError: cannot import name 'SamplingParams' from 'vllm' (unknown location)
So I import SamplingParams from the source module directly to avoid this problem. I think this modification will not affect either common users or developers, and it resolves the ImportError.
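For context, a minimal sketch of a fallback import that tolerates both install modes (an illustration only; the PR itself simply switches to the submodule import):

```python
# Sketch: fall back to the submodule import when an editable source install
# of vLLM does not re-export SamplingParams from the package root.
try:
    from vllm import SamplingParams
except ImportError:
    from vllm.sampling_params import SamplingParams
```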
Sounds good. I'll keep this conversation unresolved for history.
Amazing job on this PR! I've noticed that you have a lot of MODIFY comments in model.py; is this PR still in progress? Feel free to tag me when it is done, and I'll review.
@oandreeva-nv Thank you very much for your thorough review of my code and for the detailed guidance and suggestions. More changes have been made to the code and the docs based on your suggestions. Note that the vLLM multi-LoRA feature may not come with container 23.12 (-> vLLM 0.2.3). I will keep working on multi-LoRA support in the backend, fixing bugs and offering new features.
2023-12-29 update: after the fix, the server shuts down smoothly.
Lastly, I hope this PR can get merged. :) @oandreeva-nv
Which model.py should we update for the multi-lora feature?
OK, right now everything is fine if I clone the latest version of vLLM main, not the branch you posted in the docs.
Everything in this PR is current with the latest vLLM version. I've revised the documentation to include instructions on installing vLLM with punica kernel compilation.
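For reference, a hedged sketch of what such an install step might look like; the environment variable name is an assumption based on vLLM's source-build docs from around that time, so check the revised tutorial for the exact commands:

```bash
# Build vLLM from source with the punica (multi-LoRA) kernels compiled in.
git clone https://github.com/vllm-project/vllm.git
cd vllm
VLLM_INSTALL_PUNICA_KERNELS=1 pip install -e .
```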
@SamuraiBUPT, may I also ask you to fix some pre-commit issues? The easiest way is to install pre-commit and run it locally; some spelling issues should be fixed manually.
@oandreeva-nv Modifications have been made; I ran the pre-commit checks locally and fixed some issues as well. I left an issue about the LoRA-sending error.
Created pipeline (id: 14152157).
Hi @SamuraiBUPT, may I ask you to rebase your PR on top of the main branch? We've recently updated the vLLM version to 0.4.0, and it required some test changes.
@oandreeva-nv done :)
Re-running CI.
I apologize for my ignorance; I did not fully adhere to the contribution process outlined in the Triton Contribution Guidelines initially. However, I have now submitted a CLA to [email protected].
@oandreeva-nv @tanmayv25 Hello, I have addressed the following issues in the new commit:
Regarding the other issues:
By the way, I don't know where our pipeline is. I greatly appreciate you providing the pipeline ID, but could you also provide the location of the pipeline, or the platform it runs on?
Hi @SamuraiBUPT, thanks for the thorough review! I'll re-run the pipeline and let you know the results. At the moment the CI/CD process is internal and external contributors do not have access to it, so maintainers of the Triton codebase are responsible for making sure CI/CD has passed for all external PRs. I've started a new CI pipeline:
@SamuraiBUPT Thanks for resolving my questions and for the excellent contribution! The code changes look good to me. I will defer to @oandreeva-nv for final approval. We will enable the PUNICA kernels in our container image as a follow-up item.
Hi @oandreeva-nv, could you please let me know how it's going? If there are any issues with the CI test results that require bug fixes, I can help resolve them. I appreciate your time. :)
Hi @SamuraiBUPT, apologies for the delay. Our initial tests were running on V100s, but it turned out that V100s are too old, so I made some adjustments to the test infrastructure to run on A100. Results should be ready soon. Again, apologies for the delay.
Tests passed on A100: jobs/89630242
Amazing work!
Hello, vLLM recently added support for the multi-LoRA feature; see this PR for more information.
To support LoRA requests, I implemented the code for handling them in the Triton backend.
The vLLM scheme
As of 2023.11.29, vLLM supports loading local LoRA weights onto the model backbone (no auto-download from S3 or the Hugging Face hub); they stated that remote-weight downloading will be implemented later.
Besides, the LoRA path has to be configured by the sender; for more details, see their interface.
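For context, a minimal sketch of the vLLM-side interface described above, based on the LoRARequest class vLLM exposed around the 0.3/0.4 releases (field names may differ in other versions; the adapter name and path are hypothetical):

```python
# Sketch of vLLM's multi-LoRA request object (vLLM ~0.3/0.4 API).
from vllm.lora.request import LoRARequest

# An adapter is identified by a name, an integer id, and a local weight path.
lora_request = LoRARequest(
    lora_name="alpaca",                        # hypothetical adapter name
    lora_int_id=1,
    lora_local_path="/weights/loras/alpaca",   # hypothetical local path
)

# The async engine then accepts it per generation call, roughly:
# engine.generate(prompt, sampling_params, request_id, lora_request=lora_request)
```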
Triton scheme
Under these circumstances, Triton should act as a transit point that processes and converts the user's request.
For users, all they need to do is provide a target LoRA name when querying.
Triton will then check whether the target LoRA is among the local LoRA weights. If it is, Triton wraps the target LoRA information into a new LoRARequest object and sends it to the AsyncEngine for execution. See model.py for more information.
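A rough sketch of that transit-point logic, assuming a hypothetical lora_repository dict loaded from multi_lora.json (the helper and names are illustrative, not the exact code in model.py):

```python
# Sketch: resolve a requested adapter name against the locally registered
# LoRA weights and wrap it into a vLLM LoRARequest (names are hypothetical).
from typing import Optional

from vllm.lora.request import LoRARequest

# e.g. loaded from multi_lora.json during model initialization
lora_repository = {"alpaca": "/weights/loras/alpaca"}

def resolve_lora(lora_name: Optional[str]) -> Optional[LoRARequest]:
    """Return a LoRARequest for a known adapter name, or None for the base model."""
    if not lora_name:
        return None
    if lora_name not in lora_repository:
        raise ValueError(f"Unknown LoRA adapter: {lora_name}")
    lora_int_id = list(lora_repository).index(lora_name) + 1
    return LoRARequest(lora_name, lora_int_id, lora_repository[lora_name])
```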
Changes to model repository
The model.json needs new args such as "enable_lora": "true" to support multi-LoRA in vLLM; see EngineArgs for more.
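For illustration, a hedged sketch of what the relevant part of model.json might look like; the model name and other fields are placeholders, and only enable_lora is the setting discussed above:

```json
{
    "model": "meta-llama/Llama-2-7b-hf",
    "disable_log_requests": "true",
    "gpu_memory_utilization": 0.8,
    "enable_lora": "true"
}
```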
The multi_lora.json file was created to record the specific name of each LoRA adapter. model.json stores the vLLM engine initialization settings, so the LoRA weight paths cannot be stored there; I created a new file to manage all local LoRA weights. It can be configured like this:
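A hypothetical illustration of the name-to-path mapping described above (adapter names and paths are placeholders):

```json
{
    "alpaca": "/vllm_workspace/weights/loras/alpaca",
    "bactrian": "/vllm_workspace/weights/loras/bactrian"
}
```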
The Triton backend reads this file at initialization as a registry of the local LoRA weights.
Test
The code passed the tests both with and without LoRA support; not setting enable_lora does not affect the original code.
python3 client_lora.py
python3 client_lora.py -l alpaca
For more information about deployment, see docs/llama_multi_lora_tutorial.md.