Add multi-lora support for Triton vLLM backend #23

Merged: 28 commits, Apr 18, 2024
Changes from 11 commits

Commits
932e837
add lora support for backend
l1cacheDell Nov 28, 2023
3749535
finish vllm triton lora support
l1cacheDell Nov 29, 2023
d5b35f4
add docs for deploying multi-lora on triton
l1cacheDell Nov 29, 2023
9f30739
update docs
l1cacheDell Nov 29, 2023
5a92b9b
bug fix
l1cacheDell Nov 30, 2023
ba99ce7
Merge branch 'triton-inference-server:main' into main
l1cacheDell Dec 11, 2023
180d42d
CodeReview: remove comment and update docs
l1cacheDell Dec 28, 2023
a88c0a8
bug fix: non-graceful terminate
l1cacheDell Dec 29, 2023
fab6129
update docs to specify container version
l1cacheDell Dec 30, 2023
cfde6ff
Merge branch 'triton-inference-server:main' into main
l1cacheDell Jan 30, 2024
e047eea
update docs for punica kernels compilation
l1cacheDell Jan 30, 2024
546c21f
CodeReview: remove multi_lora.json, update docs and model.py logic
l1cacheDell Jan 31, 2024
7a6334b
update docs: create docker container first
l1cacheDell Jan 31, 2024
1a8eb43
add test stage 1: modify test.sh
l1cacheDell Jan 31, 2024
c45e2fc
resolve merge conflict
l1cacheDell Mar 2, 2024
aebe7a8
remove redundant lines and fix for ci test
l1cacheDell Mar 2, 2024
6dd45bd
update client_lora.py for main branch recent commits
l1cacheDell Mar 3, 2024
9c15bba
modify ci test.sh and docs
l1cacheDell Mar 5, 2024
23c2f70
remove redundant line
l1cacheDell Mar 5, 2024
35ae3c4
fix client_lora process_stream
l1cacheDell Mar 13, 2024
06a5dbd
add ci test for multi-lora
l1cacheDell Mar 15, 2024
3b44ffa
Update src/model.py
l1cacheDell Apr 9, 2024
ed4b977
modify to model.py and ci
l1cacheDell Apr 9, 2024
5f88675
spell check & helper func & copyright & version modify
l1cacheDell Apr 10, 2024
51de171
shebang: file permissions
l1cacheDell Apr 10, 2024
42d74ba
Merge branch 'triton-inference-server:main' into main
l1cacheDell Apr 11, 2024
a2c7e76
code review: changes to docs, client, ci test
l1cacheDell Apr 13, 2024
485e836
modify docs
l1cacheDell Apr 13, 2024
243 changes: 243 additions & 0 deletions docs/llama_multi_lora_tutorial.md
@@ -0,0 +1,243 @@
# Deploying a multi-LoRA vLLM backend in Triton

The idea of multi-LoRA serving was proposed recently; for background, you can read these papers:

+ [S-LoRA: Serving Thousands of Concurrent LoRA Adapters](https://arxiv.org/abs/2311.03285)
+ [Punica: Multi-Tenant LoRA Serving](https://arxiv.org/abs/2310.18547)

vLLM now supports multi-LoRA, integrating the `Punica` feature and its related CUDA kernels. See this [PR](https://github.com/vllm-project/vllm/pull/1804) for more details (the PR was merged into the vLLM main branch on 2024-01-24).

The following tutorial demonstrates how to deploy **a LLaMA model** with **multiple LoRA adapters** on Triton Inference Server using Triton's [Python-based](https://github.com/triton-inference-server/backend/blob/main/docs/python_based_backends.md#python-based-backends) [vLLM](https://github.com/triton-inference-server/vllm_backend/tree/main) backend.


## Step 0: Prepare the multi-LoRA version of vLLM
Clone the vLLM repository:

```bash
git clone https://github.com/vllm-project/vllm.git
```

Then install vLLM from source.

**NOTICE**: To enable the multi-LoRA feature and speed up inference, the vLLM developers have integrated punica kernels into the `csrc` directory. To compile the punica kernels, you need to turn on the environment variable shown below.

By default, the punica kernels will **NOT** be compiled.

All you need to do is run the following:

```bash
cd vllm
VLLM_INSTALL_PUNICA_KERNELS=1 pip install -e .
```

This may take 5-10 minutes.
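
If you want a quick sanity check that the editable install is importable, you can print the installed vLLM version (this does not verify that the punica kernels themselves were compiled):

```bash
python3 -c "import vllm; print(vllm.__version__)"
```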



## Step 1: Prepare your weights

To support multi-LoRA on Triton, you need to manage the file paths for the **model backbone** and the **LoRA weights** separately.

A typical weights repository can be as follows:

```
weights
├── backbone
│   └── llama-7b-hf
└── loras
    ├── alpaca-lora-7b
    └── bactrian-x-llama-lora-7b
```

A dedicated workspace for vLLM, the model weights, and the LoRA adapter weights is strongly recommended. You can create one with:

```bash
mkdir -p vllm_workspace/weights
cd vllm_workspace
```
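
One way to populate these directories is to pull the weights from the Hugging Face Hub with `huggingface-cli`. The repository IDs below are illustrative examples, not requirements of this tutorial; substitute whichever backbone and LoRA adapters you actually serve:

```bash
pip install -U "huggingface_hub[cli]"

mkdir -p weights/backbone weights/loras

# Example backbone weights (make sure you are licensed to download LLaMA weights)
huggingface-cli download huggyllama/llama-7b --local-dir weights/backbone/llama-7b-hf

# Example LoRA adapter weights
huggingface-cli download tloen/alpaca-lora-7b --local-dir weights/loras/alpaca-lora-7b
huggingface-cli download MBZUAI/bactrian-x-llama-7b-lora --local-dir weights/loras/bactrian-x-llama-lora-7b
```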



## Step 2: Prepare the model repository



__2.1 Download the model repository files__

To use Triton, a model repository is needed to hold the *model path*, *backend configuration*, and other information. The vLLM backend is implemented on top of the Python backend, and the vLLM parameters are read from `model.json`.

To create a Triton model repository, you can download the required files with these commands:

```bash
# NOTICE: you must first cd to your vllm_workspace path.
cd vllm_workspace

mkdir -p model_repository/vllm_model/1
wget -P model_repository/vllm_model/1 https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/model_repository/vllm_model/1/model.json
wget -P model_repository/vllm_model/ https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/model_repository/vllm_model/config.pbtxt
```

The model repository should look like this:

```
model_repository/
└── vllm_model
    ├── 1
    │   └── model.json
    └── config.pbtxt
```

---

Now you have finished the basic setup, and the file structure should look like this:

```
vllm_workspace
├── weights
│   ├── backbone
│   │   └── llama-7b-hf
│   └── loras
│       ├── alpaca-lora-7b
│       └── bactrian-x-llama-lora-7b
└── model_repository
    └── vllm_model
        ├── 1
        │   └── model.json
        └── config.pbtxt
```



__2.2 Populate `model.json`__

For this tutorial we will use the following set of parameters, specified in `model.json`.

```json
{
"model":"/vllm_workspace/weights/backbone/llama-7b-hf",
"disable_log_requests": "true",
"gpu_memory_utilization": 0.8,
"tensor_parallel_size": 2,
"block_size": 16,
"enable_lora": "true",
"max_lora_rank": 16
}
```

+ `model`: The path to the backbone model weights.
+ `disable_log_requests`: Whether to disable per-request logging when vLLM is launched.
+ `gpu_memory_utilization`: The fraction of GPU memory allocated for the model weights and vLLM's *PagedAttention* KV cache manager.
+ `tensor_parallel_size`: vLLM supports tensor parallelism; set this to the number of GPUs you want to use for serving.
+ `block_size`: The vLLM KV cache block size.
+ `enable_lora`: Must be present and set to `true` to enable vLLM's multi-LoRA support.
+ `max_lora_rank`: The maximum LoRA rank across your LoRA adapters.

The full set of parameters can be found [here](https://github.com/Yard1/vllm/blob/multi_lora/vllm/engine/arg_utils.py#L11).
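
If you installed vLLM locally in Step 0, one convenience check is to list the engine arguments your build actually accepts (the keys in `model.json` are forwarded to vLLM's engine arguments) by inspecting the `AsyncEngineArgs` dataclass; this sketch assumes the module layout of the multi-LoRA-era vLLM:

```bash
# Prints the names of all engine arguments accepted by this vLLM build
python3 -c "import dataclasses; from vllm.engine.arg_utils import AsyncEngineArgs; print(sorted(f.name for f in dataclasses.fields(AsyncEngineArgs)))"
```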



__2.3 Specify the local LoRA paths__

As of 2023-11-29, the [vLLM multi-lora PR](https://github.com/vllm-project/vllm/pull/1804) only supports applying **locally stored LoRA weights** at inference time, which means vLLM cannot pull LoRA adapters from Hugging Face. Triton therefore needs to know where the local LoRA weights live.

Create a `multi_lora.json` file under the `model_repository/vllm_model/1/` directory:

```bash
cd model_repository/vllm_model/1
touch multi_lora.json
```

A `multi_lora.json` should look like this:

```json
{
"alpaca": "/vllm_workspace/weights/loras/alpaca-lora-7b",
"bactrian": "/vllm_workspace/weights/loras/bactrian-x-llama-7b-lora"
}
```

The **key** should be the LoRA name used in requests, and the **value** should be the path to that adapter's weights on your machine.
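
For convenience, assuming the weight layout from Step 1 and that you are still in `model_repository/vllm_model/1` as in the commands above, the same file can be written in one step:

```bash
# Write multi_lora.json in the current model version directory
cat > multi_lora.json <<'EOF'
{
    "alpaca": "/vllm_workspace/weights/loras/alpaca-lora-7b",
    "bactrian": "/vllm_workspace/weights/loras/bactrian-x-llama-lora-7b"
}
EOF
```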



## Step 3: Start a Docker container for Triton vLLM serving

**A Docker container is strongly recommended for serving**, and this tutorial will only demonstrate how to launch Triton in a Docker environment.

First, create a Docker container using the NGC image built for vLLM serving:

```bash
# NOTICE: you must first cd to your vllm_workspace path.
cd vllm_workspace

sudo docker run --gpus all -it --net=host -p 8001:8001 --shm-size=12G \
--ulimit memlock=-1 --ulimit stack=67108864 -v ${PWD}:/vllm_workspace \
-w /vllm_workspace nvcr.io/nvidia/tritonserver:<xx.yy>-vllm-python-py3 \
/bin/bash
```

**NOTICE:** the version of the Triton Docker image must be specified; here we use `<xx.yy>` as a placeholder.

Triton's vLLM container has been available since the 23.10 release, and LoRA support will be added in a future release.

Here we recommend using the `23.12` version, like this:

```bash
sudo docker run --gpus all -it --net=host -p 8001:8001 --shm-size=12G \
--ulimit memlock=-1 --ulimit stack=67108864 -v ${PWD}:/vllm_workspace \
-w /vllm_workspace nvcr.io/nvidia/tritonserver:23.12-vllm-python-py3 \
/bin/bash
```

---

For **pre-24.yy containers**, the `model.py` shipped in the container does not support the multi-LoRA feature, so you need to replace the `/opt/tritonserver/backends/vllm/model.py` provided in the container with the most up-to-date version. Just follow these steps:

Download the `model.py` script from GitHub:

```bash
wget -P model_repository/vllm_model/1 https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/src/model.py
```

Copy the script into Triton's vLLM backend directory:

```bash
cp ./model_repository/vllm_model/1/model.py /opt/tritonserver/backends/vllm/
```



## Step 4: Launch Triton

```bash
tritonserver --model-store ./model_repository
```

After you start Triton you will see output on the console showing the server starting up and loading the model. When you see output like the following, Triton is ready to accept inference requests.

```
I1030 22:33:28.291908 1 grpc_server.cc:2513] Started GRPCInferenceService at 0.0.0.0:8001
I1030 22:33:28.292879 1 http_server.cc:4497] Started HTTPService at 0.0.0.0:8000
I1030 22:33:28.335154 1 http_server.cc:270] Started Metrics Service at 0.0.0.0:8002
```
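
From another shell, you can also confirm readiness via Triton's standard HTTP health endpoint (assuming the default HTTP port 8000):

```bash
# Returns HTTP 200 once the server and its models are ready
curl -v localhost:8000/v2/health/ready
```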



## Step 5: Send a request

A client script for multi-LoRA requests has been prepared; download the client script and sample prompts from source:

```bash
wget https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/client_lora.py
wget https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/prompts.txt
```

Try running this script:

```bash
python3 client_lora.py -l alpaca
```
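
The `-l` flag selects which LoRA adapter to apply, and the name must match a key in your `multi_lora.json`. For example, with the file above you can also try the second adapter:

```bash
python3 client_lora.py -l bactrian
```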


