fix: multi-lora with sample api-server
ganesh-dataminr committed Feb 27, 2024
1 parent 71bcaf9 commit 5166259
Showing 2 changed files with 57 additions and 4 deletions.
31 changes: 27 additions & 4 deletions docs/source/models/lora.rst
@@ -51,14 +51,36 @@ the third parameter is the path to the LoRA adapter.
Check out `examples/multilora_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py>`_
for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.

Serving LoRA Adapters
Serving LoRA Adapters (Sample Service)
--------------------------------------
The sample service entrypoint can be used to serve LoRA modules. To do so, we use
``--lora-modules {name}={path} {name}={path}`` to specify each LoRA module when we start the server:

.. code-block:: bash
python -m vllm.entrypoints.api_server \
--model meta-llama/Llama-2-7b-hf \
--lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
This will start a FastAPI server that accepts requests. An example request is as follows:

.. code-block:: bash
curl http://localhost:8000/generate -H "Content-Type: application/json" -d '{
"prompt": "San Francisco is a",
"max_tokens": 7,
"temperature": 1,
"adapter": "sql-lora"
}'
Note that if the ``adapter`` parameter is not included, the response will come from the base model only.
The ``adapter`` value is expected to be one of the adapter names passed with ``--lora-modules``.
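
For illustration, here is a rough equivalent of the request above using Python's ``requests`` library (a minimal sketch; it assumes the server and the ``sql-lora`` adapter from the preceding example):

.. code-block:: python
import requests

# Route the prompt through the "sql-lora" adapter; omit "adapter" to query the base model.
response = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "San Francisco is a",
        "max_tokens": 7,
        "temperature": 1,
        "adapter": "sql-lora",
    },
)
print(response.json())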

Serving LoRA Adapters
---------------------
LoRA adapted models can also be served with the Open-AI compatible vLLM server. To do so, we use
``--lora-modules {name}={path} {name}={path}`` to specify each LoRA module when we kickoff the server:
LoRA adapted models can also be served with the OpenAI-compatible vLLM server:

.. code-block:: bash
python -m vllm.entrypoints.api_server \
python -m vllm.entrypoints.openai.api_server \
--model meta-llama/Llama-2-7b-hf \
--enable-lora \
--lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
@@ -89,3 +111,4 @@ with its base model:
Requests can specify the LoRA adapter as if it were any other model via the ``model`` request parameter. The requests will be
processed according to the server-wide LoRA configuration (i.e. in parallel with base model requests, and potentially other
LoRA adapter requests if they were provided and ``max_loras`` is set high enough).
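
For illustration, a completion request that targets the ``sql-lora`` adapter through the OpenAI-compatible server might look like the following (a minimal sketch using Python's ``requests`` library; the endpoint and adapter name follow the example above):

.. code-block:: python
import requests

# "model" is the adapter name registered via --lora-modules;
# passing the base model name instead queries the base model without LoRA.
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "sql-lora",
        "prompt": "San Francisco is a",
        "max_tokens": 7,
        "temperature": 0,
    },
)
print(response.json())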

30 changes: 30 additions & 0 deletions vllm/entrypoints/api_server.py
@@ -17,9 +17,13 @@
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

from vllm.lora.request import LoRARequest
from vllm.entrypoints.openai.api_server import LoRAParserAction

TIMEOUT_KEEP_ALIVE = 5 # seconds.
app = FastAPI()
engine = None
adapters = {}


@app.get("/health")
@@ -34,19 +38,29 @@ async def generate(request: Request) -> Response:
    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - adapter: name of the LoRA adapter to be used.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    adapter = request_dict.pop("adapter", None)
    prefix_pos = request_dict.pop("prefix_pos", None)
    stream = request_dict.pop("stream", False)
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

    if not adapter:
        lora_request = None
    elif adapter not in adapters:
        raise ValueError(f"{adapter} not a valid adapter in this service")
    else:
        lora_request = adapters[adapter]

    results_generator = engine.generate(prompt,
                                        sampling_params,
                                        request_id,
                                        lora_request=lora_request,
                                        prefix_pos=prefix_pos)

    # Streaming case
@@ -89,11 +103,27 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    parser.add_argument(
        "--lora-modules",
        type=str,
        default=None,
        nargs='+',
        action=LoRAParserAction,
        help=
        "LoRA module configurations in the format name=path. Multiple modules can be specified."
    )
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    adapters = {
        lora.name: LoRARequest(
            lora_name=lora.name,
            lora_int_id=i,
            lora_local_path=lora.local_path,
        ) for i, lora in enumerate(args.lora_modules, start=1)
    } if args.enable_lora and args.lora_modules else {}

    app.root_path = args.root_path
    uvicorn.run(app,
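
For context, the ``adapters`` mapping built above assumes that each entry parsed by ``LoRAParserAction`` exposes ``name`` and ``local_path`` attributes. As a rough, standalone sketch of that assumed ``name=path`` parsing behavior (the class and attribute names below are illustrative, not the actual vLLM implementation):

# Illustrative sketch only: approximates what LoRAParserAction is assumed to do,
# judging from how args.lora_modules is consumed above.
import argparse
from dataclasses import dataclass


@dataclass
class LoRAModule:
    # Hypothetical container exposing the attributes used when building `adapters`.
    name: str
    local_path: str


class SketchLoRAParserAction(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        modules = []
        for item in values:
            name, path = item.split("=", 1)
            modules.append(LoRAModule(name=name, local_path=path))
        setattr(namespace, self.dest, modules)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--lora-modules",
                        type=str,
                        nargs="+",
                        default=None,
                        action=SketchLoRAParserAction)
    args = parser.parse_args([
        "--lora-modules",
        "sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/",
    ])
    print([(m.name, m.local_path) for m in args.lora_modules])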
