Add multi-lora support for Triton vLLM backend #23

Merged: 28 commits, Apr 18, 2024
Changes from 26 commits
Commits (28):
932e837  add lora support for backend (l1cacheDell, Nov 28, 2023)
3749535  finish vllm triton lora support (l1cacheDell, Nov 29, 2023)
d5b35f4  add docs for deploying multi-lora on triton (l1cacheDell, Nov 29, 2023)
9f30739  update docs (l1cacheDell, Nov 29, 2023)
5a92b9b  bug fix (l1cacheDell, Nov 30, 2023)
ba99ce7  Merge branch 'triton-inference-server:main' into main (l1cacheDell, Dec 11, 2023)
180d42d  CodeReview: remove comment and update docs (l1cacheDell, Dec 28, 2023)
a88c0a8  bug fix: non-graceful terminate (l1cacheDell, Dec 29, 2023)
fab6129  update docs to specify container version (l1cacheDell, Dec 30, 2023)
cfde6ff  Merge branch 'triton-inference-server:main' into main (l1cacheDell, Jan 30, 2024)
e047eea  update docs for punica kernels compilation (l1cacheDell, Jan 30, 2024)
546c21f  CodeReview: remove multi_lora.json, update docs and model.py logic (l1cacheDell, Jan 31, 2024)
7a6334b  update docs: create docker container first (l1cacheDell, Jan 31, 2024)
1a8eb43  add test stage 1: modify test.sh (l1cacheDell, Jan 31, 2024)
c45e2fc  resolve merge conflict (l1cacheDell, Mar 2, 2024)
aebe7a8  remove redundant lines and fix for ci test (l1cacheDell, Mar 2, 2024)
6dd45bd  update client_lora.py for main branch recent commits (l1cacheDell, Mar 3, 2024)
9c15bba  modify ci test.sh and docs (l1cacheDell, Mar 5, 2024)
23c2f70  remove redundant line (l1cacheDell, Mar 5, 2024)
35ae3c4  fix client_lora process_stream (l1cacheDell, Mar 13, 2024)
06a5dbd  add ci test for multi-lora (l1cacheDell, Mar 15, 2024)
3b44ffa  Update src/model.py (l1cacheDell, Apr 9, 2024)
ed4b977  modify to model.py and ci (l1cacheDell, Apr 9, 2024)
5f88675  spell check & helper func & copyright & version modify (l1cacheDell, Apr 10, 2024)
51de171  shebang: file permissions (l1cacheDell, Apr 10, 2024)
42d74ba  Merge branch 'triton-inference-server:main' into main (l1cacheDell, Apr 11, 2024)
a2c7e76  code review: changes to docs, client, ci test (l1cacheDell, Apr 13, 2024)
485e836  modify docs (l1cacheDell, Apr 13, 2024)
ci/L0_backend_vllm/multi_lora/download.py: 47 additions & 0 deletions (new file)
@@ -0,0 +1,47 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from huggingface_hub import snapshot_download

if __name__ == "__main__":
    # Download the Alpaca LoRA adapter weights
snapshot_download(
repo_id="tloen/alpaca-lora-7b",
local_dir="./weights/loras/alpaca",
max_workers=8,
)
    # Download the WizardLM LoRA adapter weights
snapshot_download(
repo_id="winddude/wizardLM-LlaMA-LoRA-7B",
local_dir="./weights/loras/WizardLM",
max_workers=8,
)
    # Download the llama-7b-hf backbone model
snapshot_download(
repo_id="luodian/llama-7b-hf",
local_dir="./weights/backbone/llama-7b-hf",
max_workers=8,
)
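
The script above only downloads the checkpoints; it does not verify them. As a quick, hypothetical sanity check (not part of this change), something like the following could confirm that the two adapters and the backbone landed where the CI test expects them. The marker file names are assumptions based on the standard PEFT and Hugging Face layouts (adapter_config.json for a LoRA adapter, config.json for the base model).

# Hypothetical helper, assuming the directory layout created by download.py above.
import os

expected = {
    "./weights/loras/alpaca": "adapter_config.json",
    "./weights/loras/WizardLM": "adapter_config.json",
    "./weights/backbone/llama-7b-hf": "config.json",
}

for directory, marker in expected.items():
    path = os.path.join(directory, marker)
    print(("ok" if os.path.isfile(path) else "MISSING") + f": {path}")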
ci/L0_backend_vllm/multi_lora/multi_lora_test.py: 181 additions & 0 deletions (new file)
@@ -0,0 +1,181 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import sys
import unittest
from functools import partial
from typing import List

import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException

sys.path.append("../../common")
from test_util import AsyncTestResultCollector, UserData, callback, create_vllm_request

PROMPTS = ["Instruct: What do you think of Computer Science?\nOutput:"]
SAMPLING_PARAMETERS = {"temperature": "0", "top_p": "1"}

server_enable_lora = True


class VLLMTritonLoraTest(AsyncTestResultCollector):
def setUp(self):
self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
self.vllm_model_name = "vllm_llama_multi_lora"

def _test_vllm_model(
self,
prompts: List[str],
sampling_parameters,
lora_name: List[str],
server_enable_lora=True,
stream=False,
exclude_input_in_output=None,
expected_output=None,
):
assert len(prompts) == len(
lora_name
), "The number of prompts and lora names should be the same"
user_data = UserData()
number_of_vllm_reqs = len(prompts)

self.triton_client.start_stream(callback=partial(callback, user_data))
for i in range(number_of_vllm_reqs):
lora = lora_name[i] if lora_name else None
sam_para_copy = sampling_parameters.copy()
if lora is not None:
sam_para_copy["lora_name"] = lora
request_data = create_vllm_request(
prompts[i],
i,
stream,
sam_para_copy,
self.vllm_model_name,
exclude_input_in_output=exclude_input_in_output,
)
self.triton_client.async_stream_infer(
model_name=self.vllm_model_name,
request_id=request_data["request_id"],
inputs=request_data["inputs"],
outputs=request_data["outputs"],
parameters=sampling_parameters,
)

for i in range(number_of_vllm_reqs):
result = user_data._completed_requests.get()
if type(result) is InferenceServerException:
print(result.message())
if server_enable_lora:
self.assertEqual(
str(result.message()),
f"LoRA {lora_name[i]} is not supported, we currently support ['alpaca', 'WizardLM']",
"InferenceServerException",
)
else:
self.assertEqual(
str(result.message()),
"LoRA feature is not enabled.",
"InferenceServerException",
)
self.triton_client.stop_stream()
return

output = result.as_numpy("text_output")
self.assertIsNotNone(output, "`text_output` should not be None")
if expected_output is not None:
self.assertEqual(
output,
expected_output[i],
                    'Actual and expected outputs do not match.\n \
                    Expected: "{}" \n Actual: "{}"'.format(
                        expected_output[i], output
                    ),
)

self.triton_client.stop_stream()

def test_multi_lora_requests(self):
self.triton_client.load_model(self.vllm_model_name)
sampling_parameters = {"temperature": "0", "top_p": "1"}
        # Send the two requests separately so that differences in response arrival order do not affect the checks
prompt_1 = ["Instruct: What do you think of Computer Science?\nOutput:"]
lora_1 = ["alpaca"]
expected_output = [
b" I think Computer Science is an interesting and exciting field. It is constantly evol"
]
self._test_vllm_model(
prompt_1,
sampling_parameters,
lora_name=lora_1,
server_enable_lora=server_enable_lora,
stream=False,
exclude_input_in_output=True,
expected_output=expected_output,
)

prompt_2 = ["Instruct: Tell me more about soccer\nOutput:"]
lora_2 = ["WizardLM"]
expected_output = [
b" Soccer is a team sport played between two teams of eleven players each. The object"
]
self._test_vllm_model(
prompt_2,
sampling_parameters,
lora_name=lora_2,
server_enable_lora=server_enable_lora,
stream=False,
exclude_input_in_output=True,
expected_output=expected_output,
)
self.triton_client.unload_model(self.vllm_model_name)

def test_none_exist_lora(self):
self.triton_client.load_model(self.vllm_model_name)
prompts = [
"Instruct: What is the capital city of France?\nOutput:",
]
loras = ["bactrian"]
sampling_parameters = {"temperature": "0", "top_p": "1"}
self._test_vllm_model(
prompts,
sampling_parameters,
lora_name=loras,
server_enable_lora=server_enable_lora,
stream=False,
exclude_input_in_output=True,
            expected_output=None,  # this request should fail with a LoRA-not-supported error, so no output is expected
)
self.triton_client.unload_model(self.vllm_model_name)

def tearDown(self):
self.triton_client.close()


if __name__ == "__main__":
server_enable_lora = os.environ.get("SERVER_ENABLE_LORA", "false").lower() == "true"

unittest.main()
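
For reference, the request flow exercised by this test can be reproduced without the shared test helpers. The following is a minimal standalone sketch, not part of this PR: it assumes the tensor names used by the backend's sample model configuration (text_input, stream, sampling_parameters, text_output) and the streaming mode used above, and it selects a LoRA adapter simply by adding "lora_name" to the serialized sampling parameters.

# Minimal standalone client sketch (not part of this PR). Tensor names and the
# model name are assumptions that must match the deployed model configuration.
import json
import queue
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient


def callback(result_queue, result, error):
    # Each streamed response arrives here as either a result or an error.
    result_queue.put(error if error is not None else result)


responses = queue.Queue()
client = grpcclient.InferenceServerClient(url="localhost:8001")

prompt = np.array(["Instruct: Tell me more about soccer\nOutput:".encode("utf-8")], dtype=np.object_)
stream_flag = np.array([False], dtype=bool)
sampling = {"temperature": "0", "top_p": "1", "lora_name": "WizardLM"}
sampling_np = np.array([json.dumps(sampling).encode("utf-8")], dtype=np.object_)

inputs = [
    grpcclient.InferInput("text_input", [1], "BYTES"),
    grpcclient.InferInput("stream", [1], "BOOL"),
    grpcclient.InferInput("sampling_parameters", [1], "BYTES"),
]
inputs[0].set_data_from_numpy(prompt)
inputs[1].set_data_from_numpy(stream_flag)
inputs[2].set_data_from_numpy(sampling_np)
outputs = [grpcclient.InferRequestedOutput("text_output")]

client.start_stream(callback=partial(callback, responses))
client.async_stream_infer(
    model_name="vllm_llama_multi_lora",
    request_id="0",
    inputs=inputs,
    outputs=outputs,
)
result = responses.get()
client.stop_stream()
if isinstance(result, Exception):
    print(result)  # e.g. the "LoRA ... is not supported" error asserted above
else:
    print(result.as_numpy("text_output"))

If the named adapter is not one the server was launched with, the response is an InferenceServerException, which is exactly what test_none_exist_lora asserts.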