diff --git a/docs/source/models/lora.rst b/docs/source/models/lora.rst
index 1910f26506611..1922a7d1d6d22 100644
--- a/docs/source/models/lora.rst
+++ b/docs/source/models/lora.rst
@@ -51,14 +51,39 @@ the third parameter is the path to the LoRA adapter.
 Check out `examples/multilora_inference.py `_
 for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.
 
-Serving LoRA Adapters
+Serving LoRA Adapters (Sample Service)
+--------------------------------------
+The sample service entrypoint can be used to serve LoRA modules. To do so, we use
+``--lora-modules {name}={path} {name}={path}`` to specify each LoRA module when we kick off the server:
+
+.. code-block:: bash
+
+    python -m vllm.entrypoints.api_server \
+        --model meta-llama/Llama-2-7b-hf \
+        --enable-lora \
+        --lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
+
+This will start a FastAPI server that accepts requests. An example request is as follows:
+
+.. code-block:: bash
+
+    curl http://localhost:8000/generate -H "Content-Type: application/json" -d '{
+        "prompt": "San Francisco is a",
+        "max_tokens": 7,
+        "temperature": 1,
+        "adapter": "sql-lora"
+    }'
+
+Note that if the ``adapter`` parameter is not included, the response will come from the base model only.
+The ``adapter`` value is expected to be one of the adapter names passed with ``--lora-modules``.
+
+Serving LoRA Adapters
 ---------------------
-LoRA adapted models can also be served with the Open-AI compatible vLLM server. To do so, we use
-``--lora-modules {name}={path} {name}={path}`` to specify each LoRA module when we kickoff the server:
+LoRA adapted models can also be served with the OpenAI-compatible vLLM server:
 
 .. code-block:: bash
 
-    python -m vllm.entrypoints.api_server \
+    python -m vllm.entrypoints.openai.api_server \
         --model meta-llama/Llama-2-7b-hf \
         --enable-lora \
         --lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
@@ -89,3 +114,4 @@ with its base model:
 Requests can specify the LoRA adapter as if it were any other model via the ``model`` request parameter. The requests
 will be processed according to the server-wide LoRA configuration (i.e. in parallel with base model requests,
 and potentially other LoRA adapter requests if they were provided and ``max_loras`` is set high enough).
+
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index e7af2c6db5e4c..7698c504a120e 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -17,9 +17,13 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 
+from vllm.lora.request import LoRARequest
+from vllm.entrypoints.openai.api_server import LoRAParserAction
+
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
 app = FastAPI()
 engine = None
+adapters = {}
 
 
 @app.get("/health")
@@ -34,19 +38,29 @@ async def generate(request: Request) -> Response:
 
     The request should be a JSON object with the following fields:
     - prompt: the prompt to use for the generation.
+    - adapter: name of the LoRA adapter to be used.
     - stream: whether to stream the results or not.
     - other fields: the sampling parameters (See `SamplingParams` for details).
""" request_dict = await request.json() prompt = request_dict.pop("prompt") + adapter = request_dict.pop("adapter", None) prefix_pos = request_dict.pop("prefix_pos", None) stream = request_dict.pop("stream", False) sampling_params = SamplingParams(**request_dict) request_id = random_uuid() + if not adapter: + lora_request = None + elif adapter not in adapters: + raise ValueError(f"{adapter} not a valid adapter in this service") + else: + lora_request = adapters[adapter] + results_generator = engine.generate(prompt, sampling_params, request_id, + lora_request=lora_request, prefix_pos=prefix_pos) # Streaming case @@ -89,11 +103,27 @@ async def stream_results() -> AsyncGenerator[bytes, None]: type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument( + "--lora-modules", + type=str, + default=None, + nargs='+', + action=LoRAParserAction, + help= + "LoRA module configurations in the format name=path. Multiple modules can be specified." + ) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args(engine_args) + adapters = { + lora.name: LoRARequest( + lora_name=lora.name, + lora_int_id=i, + lora_local_path=lora.local_path, + ) for i, lora in enumerate(args.lora_modules, start=1) + } if args.enable_lora else {} app.root_path = args.root_path uvicorn.run(app,