diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt
index c358e23b6a37a..6687929c0bebe 100644
--- a/docs/requirements-docs.txt
+++ b/docs/requirements-docs.txt
@@ -11,6 +11,5 @@ pydantic >= 2.8
 torch
 py-cpuinfo
 transformers
-openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
 mistral_common >= 1.3.4
 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
\ No newline at end of file
diff --git a/docs/source/models/lora.rst b/docs/source/models/lora.rst
index f08773fe59d92..b3821ebdfceca 100644
--- a/docs/source/models/lora.rst
+++ b/docs/source/models/lora.rst
@@ -107,3 +107,55 @@ The following is an example request
         "max_tokens": 7,
         "temperature": 0
     }' | jq
+
+
+Dynamically serving LoRA Adapters
+---------------------------------
+
+In addition to serving LoRA adapters at server startup, the vLLM server now supports dynamically loading and unloading
+LoRA adapters at runtime through dedicated API endpoints. This feature can be particularly useful when the flexibility
+to change models on the fly is needed.
+
+Note: Enabling this feature in production environments is risky, as it allows users to take part in model adapter management.
+
+To enable dynamic LoRA loading and unloading, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
+is set to `True`. When this option is enabled, the API server will log a warning to indicate that dynamic loading is active.
+
+.. code-block:: bash
+
+    export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
+
+
+Loading a LoRA Adapter:
+
+To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary
+details of the adapter to be loaded. The request payload should include the name and path of the LoRA adapter.
+
+Example request to load a LoRA adapter:
+
+.. code-block:: bash
+
+    curl -X POST http://localhost:8000/v1/load_lora_adapter \
+    -H "Content-Type: application/json" \
+    -d '{
+        "lora_name": "sql_adapter",
+        "lora_path": "/path/to/sql-lora-adapter"
+    }'
+
+Upon a successful request, the API will respond with a 200 OK status code. If an error occurs, such as when the adapter
+cannot be found or loaded, an appropriate error message will be returned.
+
+Unloading a LoRA Adapter:
+
+To unload a LoRA adapter that has been previously loaded, send a POST request to the `/v1/unload_lora_adapter` endpoint
+with the name or ID of the adapter to be unloaded.
+
+Example request to unload a LoRA adapter:
+
+.. 
code-block:: bash + + curl -X POST http://localhost:8000/v1/unload_lora_adapter \ + -H "Content-Type: application/json" \ + -d '{ + "lora_name": "sql_adapter" + }' diff --git a/tests/entrypoints/llm/test_generate_multiple_loras.py b/tests/entrypoints/llm/test_generate_multiple_loras.py index 35eabf079964a..9f5727ecd0406 100644 --- a/tests/entrypoints/llm/test_generate_multiple_loras.py +++ b/tests/entrypoints/llm/test_generate_multiple_loras.py @@ -50,7 +50,7 @@ def zephyr_lora_files(): @pytest.mark.skip_global_cleanup def test_multiple_lora_requests(llm: LLM, zephyr_lora_files): lora_request = [ - LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files) + LoRARequest(LORA_NAME + str(idx), idx + 1, zephyr_lora_files) for idx in range(len(PROMPTS)) ] # Multiple SamplingParams should be matched with each prompt diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py new file mode 100644 index 0000000000000..325bc03434287 --- /dev/null +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -0,0 +1,107 @@ +from http import HTTPStatus +from unittest.mock import MagicMock + +import pytest + +from vllm.config import ModelConfig +from vllm.engine.protocol import AsyncEngineClient +from vllm.entrypoints.openai.protocol import (ErrorResponse, + LoadLoraAdapterRequest, + UnloadLoraAdapterRequest) +from vllm.entrypoints.openai.serving_engine import OpenAIServing + +MODEL_NAME = "meta-llama/Llama-2-7b" +LORA_LOADING_SUCCESS_MESSAGE = ( + "Success: LoRA adapter '{lora_name}' added successfully.") +LORA_UNLOADING_SUCCESS_MESSAGE = ( + "Success: LoRA adapter '{lora_name}' removed successfully.") + + +async def _async_serving_engine_init(): + mock_engine_client = MagicMock(spec=AsyncEngineClient) + mock_model_config = MagicMock(spec=ModelConfig) + # Set the max_model_len attribute to avoid missing attribute + mock_model_config.max_model_len = 2048 + + serving_engine = OpenAIServing(mock_engine_client, + mock_model_config, + served_model_names=[MODEL_NAME], + lora_modules=None, + prompt_adapters=None, + request_logger=None) + return serving_engine + + +@pytest.mark.asyncio +async def test_load_lora_adapter_success(): + serving_engine = await _async_serving_engine_init() + request = LoadLoraAdapterRequest(lora_name="adapter", + lora_path="/path/to/adapter2") + response = await serving_engine.load_lora_adapter(request) + assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter') + assert len(serving_engine.lora_requests) == 1 + assert serving_engine.lora_requests[0].lora_name == "adapter" + + +@pytest.mark.asyncio +async def test_load_lora_adapter_missing_fields(): + serving_engine = await _async_serving_engine_init() + request = LoadLoraAdapterRequest(lora_name="", lora_path="") + response = await serving_engine.load_lora_adapter(request) + assert isinstance(response, ErrorResponse) + assert response.type == "InvalidUserInput" + assert response.code == HTTPStatus.BAD_REQUEST + + +@pytest.mark.asyncio +async def test_load_lora_adapter_duplicate(): + serving_engine = await _async_serving_engine_init() + request = LoadLoraAdapterRequest(lora_name="adapter1", + lora_path="/path/to/adapter1") + response = await serving_engine.load_lora_adapter(request) + assert response == LORA_LOADING_SUCCESS_MESSAGE.format( + lora_name='adapter1') + assert len(serving_engine.lora_requests) == 1 + + request = LoadLoraAdapterRequest(lora_name="adapter1", + lora_path="/path/to/adapter1") + response = await serving_engine.load_lora_adapter(request) + assert 
isinstance(response, ErrorResponse) + assert response.type == "InvalidUserInput" + assert response.code == HTTPStatus.BAD_REQUEST + assert len(serving_engine.lora_requests) == 1 + + +@pytest.mark.asyncio +async def test_unload_lora_adapter_success(): + serving_engine = await _async_serving_engine_init() + request = LoadLoraAdapterRequest(lora_name="adapter1", + lora_path="/path/to/adapter1") + response = await serving_engine.load_lora_adapter(request) + assert len(serving_engine.lora_requests) == 1 + + request = UnloadLoraAdapterRequest(lora_name="adapter1") + response = await serving_engine.unload_lora_adapter(request) + assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format( + lora_name='adapter1') + assert len(serving_engine.lora_requests) == 0 + + +@pytest.mark.asyncio +async def test_unload_lora_adapter_missing_fields(): + serving_engine = await _async_serving_engine_init() + request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None) + response = await serving_engine.unload_lora_adapter(request) + assert isinstance(response, ErrorResponse) + assert response.type == "InvalidUserInput" + assert response.code == HTTPStatus.BAD_REQUEST + + +@pytest.mark.asyncio +async def test_unload_lora_adapter_not_found(): + serving_engine = await _async_serving_engine_init() + request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter") + response = await serving_engine.unload_lora_adapter(request) + assert isinstance(response, ErrorResponse) + assert response.type == "InvalidUserInput" + assert response.code == HTTPStatus.BAD_REQUEST diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7632e8aa5e32e..a7dded8dc5f97 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -35,11 +35,13 @@ DetokenizeResponse, EmbeddingRequest, EmbeddingResponse, ErrorResponse, + LoadLoraAdapterRequest, TokenizeRequest, - TokenizeResponse) -# yapf: enable + TokenizeResponse, + UnloadLoraAdapterRequest) from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient from vllm.entrypoints.openai.rpc.server import run_rpc_server +# yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -340,6 +342,40 @@ async def stop_profile(): return Response(status_code=200) +if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + logger.warning( + "Lora dynamic loading & unloading is enabled in the API server. 
" + "This should ONLY be used for local development!") + + @router.post("/v1/load_lora_adapter") + async def load_lora_adapter(request: LoadLoraAdapterRequest): + response = await openai_serving_chat.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + response = await openai_serving_completion.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + @router.post("/v1/unload_lora_adapter") + async def unload_lora_adapter(request: UnloadLoraAdapterRequest): + response = await openai_serving_chat.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + response = await openai_serving_completion.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + def build_app(args: Namespace) -> FastAPI: app = FastAPI(lifespan=lifespan) app.include_router(router) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0954b81595ef5..002c18f473a4d 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -777,3 +777,13 @@ class DetokenizeRequest(OpenAIBaseModel): class DetokenizeResponse(OpenAIBaseModel): prompt: str + + +class LoadLoraAdapterRequest(BaseModel): + lora_name: str + lora_path: str + + +class UnloadLoraAdapterRequest(BaseModel): + lora_name: str + lora_int_id: Optional[int] = Field(default=None) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 26e91e7cc94dd..ac74527441cd9 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -16,11 +16,13 @@ CompletionRequest, DetokenizeRequest, EmbeddingRequest, ErrorResponse, + LoadLoraAdapterRequest, ModelCard, ModelList, ModelPermission, TokenizeChatRequest, TokenizeCompletionRequest, - TokenizeRequest) + TokenizeRequest, + UnloadLoraAdapterRequest) # yapf: enable from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger @@ -32,6 +34,7 @@ from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import AtomicCounter logger = init_logger(__name__) @@ -78,6 +81,7 @@ def __init__( self.served_model_names = served_model_names + self.lora_id_counter = AtomicCounter(0) self.lora_requests = [] if lora_modules is not None: self.lora_requests = [ @@ -403,3 +407,76 @@ def _get_decoded_token(logprob: Logprob, if logprob.decoded_token is not None: return logprob.decoded_token return tokenizer.decode(token_id) + + async def _check_load_lora_adapter_request( + self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]: + # Check if both 'lora_name' and 'lora_path' are provided + if not request.lora_name or not request.lora_path: + return self.create_error_response( + message="Both 'lora_name' and 'lora_path' must be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name already exists + if any(lora_request.lora_name == request.lora_name + for lora_request in 
self.lora_requests):
+            return self.create_error_response(
+                message=
+                f"The lora adapter '{request.lora_name}' has already been "
+                "loaded.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        return None
+
+    async def _check_unload_lora_adapter_request(
+            self,
+            request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
+        # Check if either 'lora_name' or 'lora_int_id' is provided
+        if not request.lora_name and not request.lora_int_id:
+            return self.create_error_response(
+                message=
+                "Either 'lora_name' or 'lora_int_id' needs to be provided.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        # Check if the lora adapter with the given name exists
+        if not any(lora_request.lora_name == request.lora_name
+                   for lora_request in self.lora_requests):
+            return self.create_error_response(
+                message=
+                f"The lora adapter '{request.lora_name}' cannot be found.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        return None
+
+    async def load_lora_adapter(
+            self,
+            request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
+        error_check_ret = await self._check_load_lora_adapter_request(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        lora_name, lora_path = request.lora_name, request.lora_path
+        unique_id = self.lora_id_counter.inc(1)
+        self.lora_requests.append(
+            LoRARequest(lora_name=lora_name,
+                        lora_int_id=unique_id,
+                        lora_path=lora_path))
+        return f"Success: LoRA adapter '{lora_name}' added successfully."
+
+    async def unload_lora_adapter(
+            self,
+            request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
+        error_check_ret = await self._check_unload_lora_adapter_request(
+            request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        lora_name = request.lora_name
+        self.lora_requests = [
+            lora_request for lora_request in self.lora_requests
+            if lora_request.lora_name != lora_name
+        ]
+        return f"Success: LoRA adapter '{lora_name}' removed successfully."
diff --git a/vllm/envs.py b/vllm/envs.py
index 3c6b6adff82fc..ed45047e9f8fc 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -61,6 +61,7 @@
     VLLM_ALLOW_ENGINE_USE_RAY: bool = False
     VLLM_PLUGINS: Optional[List[str]] = None
     VLLM_TORCH_PROFILER_DIR: Optional[str] = None
+    VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False


 def get_default_cache_root():
@@ -409,6 +410,12 @@ def get_default_config_root():
     # If set, vLLM will use Triton implementations of AWQ.
     "VLLM_USE_TRITON_AWQ":
     lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
+
+    # If set, allow loading or unloading LoRA adapters at runtime.
+    "VLLM_ALLOW_RUNTIME_LORA_UPDATING":
+    lambda:
+    (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in
+     ("1", "true")),
 }

 # end-env-vars-definition
diff --git a/vllm/lora/request.py b/vllm/lora/request.py
index d770da4f2407d..47a59d80d3a45 100644
--- a/vllm/lora/request.py
+++ b/vllm/lora/request.py
@@ -28,7 +28,6 @@ class LoRARequest(
     lora_path: str = ""
     lora_local_path: Optional[str] = msgspec.field(default=None)
     long_lora_max_len: Optional[int] = None
-    __hash__ = AdapterRequest.__hash__

     def __post_init__(self):
         if 'lora_local_path' in self.__struct_fields__:
@@ -75,3 +74,21 @@ def local_path(self, value):
                       DeprecationWarning,
                       stacklevel=2)
         self.lora_path = value
+
+    def __eq__(self, value: object) -> bool:
+        """
+        Overrides the equality method to compare LoRARequest
+        instances based on lora_name. This allows for identification
+        and comparison of LoRA adapters across engines.
+ """ + return isinstance(value, + self.__class__) and self.lora_name == value.lora_name + + def __hash__(self) -> int: + """ + Overrides the hash method to hash LoRARequest instances + based on lora_name. This ensures that LoRARequest instances + can be used in hash-based collections such as sets and dictionaries, + identified by their names across engines. + """ + return hash(self.lora_name) diff --git a/vllm/utils.py b/vllm/utils.py index 657a3ecef696d..a22081ebe8df0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1224,3 +1224,28 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, def supports_dynamo() -> bool: base_torch_version = Version(Version(torch.__version__).base_version) return base_torch_version >= Version("2.4.0") + + +class AtomicCounter: + """An atomic, thread-safe counter""" + + def __init__(self, initial=0): + """Initialize a new atomic counter to given initial value""" + self._value = initial + self._lock = threading.Lock() + + def inc(self, num=1): + """Atomically increment the counter by num and return the new value""" + with self._lock: + self._value += num + return self._value + + def dec(self, num=1): + """Atomically decrement the counter by num and return the new value""" + with self._lock: + self._value -= num + return self._value + + @property + def value(self): + return self._value
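Usage sketch: the minimal client below shows how the pieces of this diff fit together at runtime. It is an illustrative example rather than part of the patch, and it assumes a vLLM OpenAI-compatible server is already running on localhost:8000 with LoRA support enabled and VLLM_ALLOW_RUNTIME_LORA_UPDATING=True; the helper names load_adapter/unload_adapter, the adapter name "sql_adapter", and the path /path/to/sql-lora-adapter are placeholders.

    import requests  # illustrative client only

    BASE_URL = "http://localhost:8000"

    def load_adapter(name: str, path: str) -> None:
        # POST /v1/load_lora_adapter registers the adapter under `name`;
        # serving_engine.py assigns it a fresh integer id via AtomicCounter.
        resp = requests.post(f"{BASE_URL}/v1/load_lora_adapter",
                             json={"lora_name": name, "lora_path": path})
        # 200 on success; a 400 ErrorResponse if a field is missing or the
        # name is already loaded (see _check_load_lora_adapter_request).
        resp.raise_for_status()

    def unload_adapter(name: str) -> None:
        # POST /v1/unload_lora_adapter removes the adapter by name.
        resp = requests.post(f"{BASE_URL}/v1/unload_lora_adapter",
                             json={"lora_name": name})
        resp.raise_for_status()

    if __name__ == "__main__":
        load_adapter("sql_adapter", "/path/to/sql-lora-adapter")  # placeholder path
        # The adapter can now be referenced by name in completion requests.
        unload_adapter("sql_adapter")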