[Frontend] OpenAI entrypoint dynamic adapter load #3850

15 changes: 15 additions & 0 deletions docs/source/models/lora.rst
@@ -107,3 +107,18 @@ The following is an example request
"max_tokens": 7,
"temperature": 0
}' | jq


Alternatively, the request can specify a LoRA adapter to load dynamically from the server's local disk storage:

.. code-block:: bash

curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "sql-lora",
"prompt": "San Francisco is a",
"max_tokens": 7,
"temperature": 0,
"lora_request": {"lora_name":"sql-lora","lora_local_path":"/data/adapters/sql-lora"}
}' | jq
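
For reference, the same request can be sent from the OpenAI Python client by passing the extra field through ``extra_body`` (a minimal sketch, not part of this change; it assumes the server from the example above and that the adapter files live under ``/data/adapters/sql-lora``):

.. code-block:: python

from openai import OpenAI

# Point the client at the vLLM OpenAI-compatible server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="sql-lora",
    prompt="San Francisco is a",
    max_tokens=7,
    temperature=0,
    # Fields outside the OpenAI spec must go through extra_body so they
    # reach vLLM's CompletionRequest unchanged.
    extra_body={
        "lora_request": {
            "lora_name": "sql-lora",
            "lora_local_path": "/data/adapters/sql-lora",
        }
    },
)
print(completion.choices[0].text)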
15 changes: 14 additions & 1 deletion vllm/entrypoints/openai/protocol.py
@@ -14,6 +14,7 @@
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.utils import random_uuid
from vllm.lora.request import LoRARequest

# torch is mocked during docs generation,
# so we have to provide the values as literals
@@ -218,6 +219,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=(
"If specified, the output will follow the context free grammar."),
)
lora_request: Optional[dict] = Field(default_factory=dict)

guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
@@ -232,6 +235,11 @@

# doc: end-chat-completion-extra-params

def to_lora_params(self) -> Optional[LoRARequest]:
if not self.lora_request:
return None
return LoRARequest(**self.lora_request)
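
For illustration only, a sketch of how the new field flows through ``to_lora_params`` (the payload values here are hypothetical and mirror the docs example above):

.. code-block:: python

# A hypothetical request carrying a dynamic adapter specification.
request = ChatCompletionRequest(
    model="sql-lora",
    messages=[{"role": "user", "content": "San Francisco is a"}],
    lora_request={
        "lora_name": "sql-lora",
        "lora_local_path": "/data/adapters/sql-lora",
    },
)
lora = request.to_lora_params()  # LoRARequest built from the dict

# With no lora_request supplied, the default empty dict yields None.
base = ChatCompletionRequest(
    model="meta-llama/Llama-2-7b-hf",
    messages=[{"role": "user", "content": "Hello"}],
)
assert base.to_lora_params() is None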

def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
@@ -403,6 +411,7 @@ class CompletionRequest(OpenAIBaseModel):
description=(
"If specified, the output will follow the context free grammar."),
)
lora_request: Optional[dict] = Field(default_factory=dict)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
@@ -417,14 +426,18 @@

# doc: end-completion-extra-params

def to_lora_params(self) -> Optional[LoRARequest]:
if not self.lora_request:
return None
return LoRARequest(**self.lora_request)

def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens

echo_without_generation = self.echo and self.max_tokens == 0

logits_processors = get_logits_processors(
10 changes: 10 additions & 0 deletions vllm/entrypoints/openai/serving_engine.py
@@ -3,6 +3,7 @@
from dataclasses import dataclass
from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
import os

from pydantic import Field
from typing_extensions import Annotated
@@ -169,6 +170,8 @@ async def _check_model(
return None
if request.model in [lora.lora_name for lora in self.lora_requests]:
return None
elif request.lora_request and os.path.exists(request.lora_request.get("lora_local_path", "")):
return None
if request.model in [
prompt_adapter.prompt_adapter_name
for prompt_adapter in self.prompt_adapter_requests
@@ -188,6 +191,13 @@ def _maybe_get_adapters(
for lora in self.lora_requests:
if request.model == lora.lora_name:
return lora, None
if request.lora_request and os.path.exists(request.lora_request.get("lora_local_path", "")):
new_lora = LoRARequest(
lora_name=request.model,
lora_local_path=request.lora_request.get("lora_local_path")
)
self.lora_requests.append(new_lora)
return new_lora, None
for prompt_adapter in self.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name:
return None, prompt_adapter
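
Taken together, the two hunks above amount to the lookup below; this condensed sketch (the name ``resolve_lora`` is illustrative and not part of the change) shows the intended flow:

.. code-block:: python

import os
from typing import List, Optional

from vllm.lora.request import LoRARequest


def resolve_lora(model: str, lora_request: Optional[dict],
                 registered: List[LoRARequest]) -> Optional[LoRARequest]:
    """Condensed view of the per-request adapter resolution."""
    # 1. An adapter registered at startup (or on an earlier request) wins.
    for lora in registered:
        if model == lora.lora_name:
            return lora
    # 2. Otherwise, if the request points at an existing local path,
    #    build a LoRARequest on the fly and remember it for later calls.
    if lora_request and os.path.exists(
            lora_request.get("lora_local_path", "")):
        new_lora = LoRARequest(
            lora_name=model,
            lora_local_path=lora_request["lora_local_path"])
        registered.append(new_lora)
        return new_lora
    # 3. No adapter applies; the base model handles the request.
    return None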
20 changes: 18 additions & 2 deletions vllm/lora/request.py
@@ -1,10 +1,19 @@
import warnings
from dataclasses import dataclass, field
import warnings
from typing import Optional
import hashlib

from vllm.adapter_commons.request import AdapterRequest


def positive_hash_sha256(input_string: str) -> int:
"""
Generate a positive hash for the input string, used to identify the LoRA
model variant. SHA-256 is used to keep the value consistent across Python
versions and the sheets addon.
"""
return int(hashlib.sha256(input_string.encode('utf-8')).hexdigest(), 16) % (2 ** 63)


@dataclass
class LoRARequest(AdapterRequest):
"""
@@ -20,7 +29,7 @@ class LoRARequest(AdapterRequest):
"""

lora_name: str
lora_int_id: int
lora_int_id: Optional[int] = 0
lora_path: str = ""
lora_local_path: Optional[str] = field(default=None, repr=False)
long_lora_max_len: Optional[int] = None
@@ -37,6 +46,13 @@ def __post_init__(self):
if not self.lora_path:
self.lora_path = self.lora_local_path or ""

# if no int_id was given, use the name hash as id
if not self.lora_int_id:
self.lora_int_id = positive_hash_sha256(self.lora_name)
if self.lora_int_id < 1:
raise ValueError(
f"lora_int_id must be > 0, got {self.lora_int_id}")

# Ensure lora_path is not empty
assert self.lora_path, "lora_path cannot be empty"
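
A small usage sketch of the new default (the adapter name and path are made up; the point is that the id is derived deterministically from the name when none is given):

.. code-block:: python

from vllm.lora.request import LoRARequest

# No lora_int_id supplied: __post_init__ derives a positive id from the
# SHA-256 hash of the adapter name.
req = LoRARequest(lora_name="sql-lora",
                  lora_local_path="/data/adapters/sql-lora")
assert req.lora_int_id > 0

# The hash is stable, so the same name always maps to the same id.
same = LoRARequest(lora_name="sql-lora",
                   lora_local_path="/data/adapters/sql-lora")
assert same.lora_int_id == req.lora_int_id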
