Support load and unload LoRA in api server
Jeffwan committed Jul 19, 2024
1 parent 6366efc commit 6f4cefb
Showing 4 changed files with 136 additions and 3 deletions.
34 changes: 33 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -27,8 +27,10 @@
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest, ErrorResponse,
LoadLoraAdapterRequest,
TokenizeRequest,
TokenizeResponse,
UnloadLoraAdapterRequest)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -164,6 +166,36 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump())


@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest):
response = await openai_serving_chat.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

response = await openai_serving_completion.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

return Response(status_code=200)


@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
response = await openai_serving_chat.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

response = await openai_serving_completion.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

return Response(status_code=200)
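
For illustration, a hypothetical client session against these two routes; the server address, adapter name, and adapter path are placeholders, assuming a server built from this branch is running:

import requests

BASE_URL = "http://localhost:8000"  # placeholder address of a running server

# Load an adapter by name and path; the JSON fields match
# LoadLoraAdapterRequest in protocol.py below.
resp = requests.post(f"{BASE_URL}/v1/load_lora_adapter",
                     json={"lora_name": "sql-lora",             # placeholder
                           "lora_path": "/adapters/sql-lora"})  # placeholder
print(resp.status_code)  # 200 on success, an ErrorResponse body otherwise

# Unload the same adapter by name; fields match UnloadLoraAdapterRequest.
resp = requests.post(f"{BASE_URL}/v1/unload_lora_adapter",
                     json={"lora_name": "sql-lora"})
print(resp.status_code)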


def build_app(args):
app = fastapi.FastAPI(lifespan=lifespan)
app.include_router(router)
10 changes: 10 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -758,3 +758,13 @@ class DetokenizeRequest(OpenAIBaseModel):

class DetokenizeResponse(OpenAIBaseModel):
prompt: str


class LoadLoraAdapterRequest(BaseModel):
lora_name: str
lora_path: str


class UnloadLoraAdapterRequest(BaseModel):
lora_name: str
    lora_int_id: Optional[int] = None  # optional; unload may go by name alone
76 changes: 75 additions & 1 deletion vllm/entrypoints/openai/serving_engine.py
@@ -14,8 +14,10 @@
CompletionRequest,
DetokenizeRequest,
EmbeddingRequest, ErrorResponse,
LoadLoraAdapterRequest,
ModelCard, ModelList,
ModelPermission, TokenizeRequest,
UnloadLoraAdapterRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -245,3 +247,75 @@ def _get_decoded_token(logprob: Logprob, token_id: int,
if logprob.decoded_token is not None:
return logprob.decoded_token
return tokenizer.decode(token_id)

async def _check_load_lora_adapter_request(
self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
# Check if both 'lora_name' and 'lora_path' are provided
if not request.lora_name or not request.lora_path:
return self.create_error_response(
message="Both 'lora_name' and 'lora_path' must be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)

# Check if the lora adapter with the given name already exists
if any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
return self.create_error_response(
message=
f"The lora adapter '{request.lora_name}' has already been"
"loaded.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)

return None

async def _check_unload_lora_adapter_request(
self,
request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
        # Check that either 'lora_name' or 'lora_int_id' is provided
        if not request.lora_name and not request.lora_int_id:
            return self.create_error_response(
                message=
                "Either 'lora_name' or 'lora_int_id' must be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)

# Check if the lora adapter with the given name exists
if not any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
return self.create_error_response(
message=
f"The lora adapter '{request.lora_name}' cannot be found.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)

return None

async def load_lora_adapter(
self,
request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_load_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret

lora_name, lora_path = request.lora_name, request.lora_path
unique_id = len(self.lora_requests) + 1
self.lora_requests.append(
LoRARequest(lora_name=lora_name,
lora_int_id=unique_id,
                        lora_local_path=lora_path))
return f"Success: LoRA adapter '{lora_name}' added successfully."

async def unload_lora_adapter(
self,
request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
        error_check_ret = await self._check_unload_lora_adapter_request(
            request)
if error_check_ret is not None:
return error_check_ret

lora_name = request.lora_name
self.lora_requests = [
lora_request for lora_request in self.lora_requests
if lora_request.lora_name != lora_name
]
return f"Success: LoRA adapter '{lora_name}' removed successfully."
19 changes: 18 additions & 1 deletion vllm/lora/request.py
@@ -22,7 +22,6 @@ class LoRARequest(AdapterRequest):
lora_int_id: int
lora_local_path: str
long_lora_max_len: Optional[int] = None

@property
def adapter_id(self):
@@ -35,3 +34,21 @@ def name(self):
@property
def local_path(self):
return self.lora_local_path

def __eq__(self, value: object) -> bool:
"""
        Overrides the equality method to compare LoRARequest
        instances based on lora_name. This allows LoRA adapters
        to be identified and compared across engines by name.
"""
return isinstance(value,
self.__class__) and self.lora_name == value.lora_name

def __hash__(self) -> int:
"""
Overrides the hash method to hash LoRARequest instances
based on lora_name. This ensures that LoRARequest instances
can be used in hash-based collections such as sets and dictionaries,
identified by their names across engines.
"""
return hash(self.lora_name)
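
With these overrides in place, two requests that share a lora_name collapse to one entry in hash-based collections even when their ids and paths differ. A minimal sketch, assuming a build that includes this patch is importable (names and paths are placeholders):

from vllm.lora.request import LoRARequest

a = LoRARequest(lora_name="sql-lora", lora_int_id=1,
                lora_local_path="/adapters/sql-lora")
b = LoRARequest(lora_name="sql-lora", lora_int_id=2,
                lora_local_path="/adapters/sql-lora-copy")

assert a == b            # equality keys on lora_name only
assert len({a, b}) == 1  # the set deduplicates by name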
