[Frontend] Chat-based Embeddings API #9759

Merged: 46 commits, Nov 1, 2024 (the diff below shows changes from 7 of the 46 commits)

Commits
1b91750
Initial implementation
DarkLight1337 Oct 28, 2024
61e0fcf
Update docs
DarkLight1337 Oct 28, 2024
c62be47
Cleanup
DarkLight1337 Oct 28, 2024
cc999b1
Consolidate and make code consistent
DarkLight1337 Oct 28, 2024
9ed87c1
Remove useless statement
DarkLight1337 Oct 28, 2024
efa7c6f
Rename back
DarkLight1337 Oct 28, 2024
ab9297e
Factor out common code
DarkLight1337 Oct 28, 2024
5a4f271
Reinstate truncate_prompt_tokens check
DarkLight1337 Oct 29, 2024
4a969b4
Rename
DarkLight1337 Oct 29, 2024
279b9ce
Fix
DarkLight1337 Oct 29, 2024
7de803f
Remove unused code
DarkLight1337 Oct 29, 2024
c1ef363
Migrate tokenization API
DarkLight1337 Oct 29, 2024
a10fa85
Some fixes
DarkLight1337 Oct 29, 2024
89e0710
format
DarkLight1337 Oct 29, 2024
81b94de
remoev unused imports
DarkLight1337 Oct 29, 2024
a79d3b2
Migrate chat and completion APIs
DarkLight1337 Oct 29, 2024
8b950dd
Factor out trace headers code
DarkLight1337 Oct 29, 2024
2c91855
Merge branch 'main' into chat-embeddings-api
DarkLight1337 Oct 29, 2024
f5e72ff
Clean
DarkLight1337 Oct 29, 2024
9cd1ac3
More precise error handling
DarkLight1337 Oct 29, 2024
d775150
Add and update tests
DarkLight1337 Oct 29, 2024
f2b5846
Cleanup
DarkLight1337 Oct 29, 2024
4a25806
Fix tests
DarkLight1337 Oct 29, 2024
bbcfc6a
Update docs
DarkLight1337 Oct 29, 2024
b6820b7
Add docs
DarkLight1337 Oct 29, 2024
fed887a
Fix doc failure
DarkLight1337 Oct 29, 2024
1774b27
Mock out starlette
DarkLight1337 Oct 29, 2024
c94aa93
Try fix docs
DarkLight1337 Oct 29, 2024
e2ecbcd
Cleanup docs
DarkLight1337 Oct 29, 2024
fbbd8b1
Fix newlines
DarkLight1337 Oct 29, 2024
50ad3aa
Reword
DarkLight1337 Oct 29, 2024
9c1df21
Fix
DarkLight1337 Oct 29, 2024
8049030
Update
DarkLight1337 Oct 29, 2024
a387845
Update
DarkLight1337 Oct 29, 2024
d80ec7e
Update
DarkLight1337 Oct 29, 2024
ea5fd96
format
DarkLight1337 Oct 29, 2024
b05ede6
Convert to tip
DarkLight1337 Oct 29, 2024
dba9806
newline
DarkLight1337 Oct 29, 2024
557c9ef
Fix missing client
DarkLight1337 Oct 30, 2024
8c8ee96
Merge branch 'main' into chat-embeddings-api
DarkLight1337 Oct 31, 2024
c3ba030
Merge branch 'main' into chat-embeddings-api
DarkLight1337 Oct 31, 2024
46f316f
Optionally initialize request handlers
DarkLight1337 Nov 1, 2024
1179f66
Update tip
DarkLight1337 Nov 1, 2024
eb4b235
Update tests
DarkLight1337 Nov 1, 2024
bf46a16
format
DarkLight1337 Nov 1, 2024
7f188f9
Rename
DarkLight1337 Nov 1, 2024
6 changes: 3 additions & 3 deletions docs/source/getting_started/quickstart.rst
@@ -138,10 +138,10 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep

A more detailed client example can be found `here <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`__.

OpenAI Chat API with vLLM
~~~~~~~~~~~~~~~~~~~~~~~~~~
OpenAI Chat Completions API with vLLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

vLLM is designed to also support the OpenAI Chat API. The chat interface is a more dynamic, interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations.
vLLM is designed to also support the OpenAI Chat Completions API. The chat interface is a more dynamic, interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations.

You can use the `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_ endpoint to interact with the model:

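For reference, a minimal call against the renamed Chat Completions endpoint might look like the sketch below. The server address, API key, and model name are placeholders, not values taken from this PR.

```python
from openai import OpenAI

# Placeholder address and model name; point these at your own vLLM deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_response = client.chat.completions.create(
    model="your-model-name",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Tell me a joke."},
    ],
)
print(chat_response.choices[0].message.content)
```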
2 changes: 1 addition & 1 deletion docs/source/serving/openai_compatible_server.md
@@ -49,7 +49,7 @@ completion = client.chat.completions.create(
)
```

### Extra Parameters for Chat API
### Extra Parameters for Chat Completions API
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.

```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
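As a rough illustration of the "extra parameters" this section documents: with the official client they are typically passed through `extra_body`, which ends up as additional keys in the JSON body that vLLM's request models parse. `top_k` and `repetition_penalty` are assumed here to be among the supported sampling parameters.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder address

completion = client.chat.completions.create(
    model="your-model-name",  # placeholder
    messages=[{"role": "user", "content": "Summarize what vLLM does in one sentence."}],
    # Fields outside the OpenAI spec travel in extra_body and are validated by
    # the server-side request model alongside the standard fields.
    extra_body={"top_k": 20, "repetition_penalty": 1.05},
)
print(completion.choices[0].message.content)
```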
13 changes: 4 additions & 9 deletions tests/entrypoints/openai/test_basic.py
@@ -1,7 +1,6 @@
from http import HTTPStatus
from typing import List

import openai
import pytest
import pytest_asyncio
import requests
@@ -83,10 +82,8 @@ async def client(server):
indirect=True,
)
@pytest.mark.asyncio
async def test_show_version(client: openai.AsyncOpenAI):
base_url = str(client.base_url)[:-3].strip("/")

response = requests.get(base_url + "/version")
async def test_show_version(server: RemoteOpenAIServer):
response = requests.get(server.url_for("version"))
response.raise_for_status()

assert response.json() == {"version": VLLM_VERSION}
@@ -102,9 +99,7 @@ async def test_show_version(client: openai.AsyncOpenAI):
indirect=True,
)
@pytest.mark.asyncio
async def test_check_health(client: openai.AsyncOpenAI):
base_url = str(client.base_url)[:-3].strip("/")

response = requests.get(base_url + "/health")
async def test_check_health(server: RemoteOpenAIServer):
response = requests.get(server.url_for("health"))

assert response.status_code == HTTPStatus.OK
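The tests above (and the ones that follow) drop the hand-rolled `base_url = str(client.base_url)[:-3].strip("/")` in favor of the `server` fixture's `url_for` helper. A simplified sketch of what such a helper presumably looks like; the real `RemoteOpenAIServer` lives in the test utilities and may differ:

```python
class RemoteOpenAIServer:
    """Simplified sketch of the test helper's URL-building surface (assumed)."""

    def __init__(self, host: str = "localhost", port: int = 8000) -> None:
        self.host = host
        self.port = port

    @property
    def url_root(self) -> str:
        return f"http://{self.host}:{self.port}"

    def url_for(self, *parts: str) -> str:
        # e.g. url_for("version") -> "http://localhost:8000/version"
        return self.url_root + "/" + "/".join(parts)
```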
13 changes: 4 additions & 9 deletions tests/entrypoints/openai/test_metrics.py
@@ -4,7 +4,6 @@
import time
from http import HTTPStatus

import openai
import pytest
import pytest_asyncio
import requests
@@ -79,17 +78,15 @@ async def client(server):


@pytest.mark.asyncio
async def test_metrics_counts(client: openai.AsyncOpenAI):
base_url = str(client.base_url)[:-3].strip("/")

async def test_metrics_counts(server: RemoteOpenAIServer):
for _ in range(_NUM_REQUESTS):
# sending a request triggers the metrics to be logged.
await client.completions.create(
model=MODEL_NAME,
prompt=_TOKENIZED_PROMPT,
max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST)

response = requests.get(base_url + "/metrics")
response = requests.get(server.url_for("metrics"))
print(response.text)
assert response.status_code == HTTPStatus.OK

@@ -170,16 +167,14 @@ async def test_metrics_counts(client: openai.AsyncOpenAI):


@pytest.mark.asyncio
async def test_metrics_exist(client: openai.AsyncOpenAI):
base_url = str(client.base_url)[:-3].strip("/")

async def test_metrics_exist(server: RemoteOpenAIServer):
# sending a request triggers the metrics to be logged.
await client.completions.create(model=MODEL_NAME,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)

response = requests.get(base_url + "/metrics")
response = requests.get(server.url_for("metrics"))
assert response.status_code == HTTPStatus.OK

for metric in EXPECTED_METRICS:
32 changes: 18 additions & 14 deletions tests/entrypoints/openai/test_tokenization.py
@@ -1,4 +1,3 @@
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
import requests
@@ -55,17 +54,19 @@ async def client(server):
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_tokenize_completions(client: openai.AsyncOpenAI,
model_name: str, tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
async def test_tokenize_completions(
server: RemoteOpenAIServer,
model_name: str,
tokenizer_name: str,
):
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")

for add_special in [False, True]:
prompt = "vllm1 This is a test prompt."
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)

response = requests.post(base_url + "/tokenize",
response = requests.post(server.url_for("tokenize"),
json={
"add_special_tokens": add_special,
"model": model_name,
@@ -86,9 +87,11 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI,
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
async def test_tokenize_chat(
server: RemoteOpenAIServer,
model_name: str,
tokenizer_name: str,
):
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")

@@ -121,7 +124,7 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
tokens = tokenizer.encode(prompt,
add_special_tokens=add_special)

response = requests.post(base_url + "/tokenize",
response = requests.post(server.url_for("tokenize"),
json={
"add_generation_prompt":
add_generation,
@@ -146,17 +149,18 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str,
tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
async def test_detokenize(
server: RemoteOpenAIServer,
model_name: str,
tokenizer_name: str,
):
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")

prompt = "This is a test prompt. vllm1"
tokens = tokenizer.encode(prompt, add_special_tokens=False)

print(f"CALLING {base_url} FOR {model_name}")
response = requests.post(base_url + "/detokenize",
response = requests.post(server.url_for("detokenize"),
json={
"model": model_name,
"tokens": tokens
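For context on the payloads these tests send: a standalone request to the tokenization endpoint could look like the sketch below. The route comes from the test code above; the server address, model name, and exact field set are assumptions and may not be exhaustive.

```python
import requests

# Placeholder address and model name.
base_url = "http://localhost:8000"

response = requests.post(
    base_url + "/tokenize",
    json={
        "model": "your-model-name",
        "prompt": "vllm1 This is a test prompt.",
        "add_special_tokens": True,
    },
)
response.raise_for_status()
print(response.json())
```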
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -513,6 +513,7 @@ def init_app_state(
model_config,
base_model_paths,
request_logger=request_logger,
chat_template=args.chat_template,
)
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
86 changes: 84 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -701,7 +701,7 @@ def validate_stream_options(cls, data):
return data


class EmbeddingRequest(OpenAIBaseModel):
class EmbeddingCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
model: str
@@ -717,6 +717,12 @@ class EmbeddingRequest(OpenAIBaseModel):
# doc: end-embedding-pooling-params

# doc: begin-embedding-extra-params
add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."),
)
priority: int = Field(
default=0,
description=(
@@ -730,6 +736,82 @@ def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)


class EmbeddingChatRequest(OpenAIBaseModel):
model: str
messages: List[ChatCompletionMessageParam]

encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

# doc: begin-chat-embedding-pooling-params
additional_data: Optional[Any] = None

# doc: begin-chat-embedding-extra-params
add_generation_prompt: bool = Field(
default=True,
description=
("If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
continue_final_message: bool = Field(
default=False,
description=
("If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
"This allows you to \"prefill\" part of the model's response for it. "
"Cannot be used at the same time as `add_generation_prompt`."),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."),
)
chat_template: Optional[str] = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))

# doc: end-chat-embedding-extra-params

@model_validator(mode="before")
@classmethod
def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get(
"add_generation_prompt"):
raise ValueError("Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True.")
return data

def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)


EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]


class CompletionLogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
@@ -792,7 +874,7 @@ class EmbeddingResponseData(OpenAIBaseModel):


class EmbeddingResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
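To illustrate what the new `EmbeddingChatRequest` enables end to end: an embeddings request can now carry chat `messages` instead of a plain `input`. The sketch below assumes the route is the usual `/v1/embeddings` and that the served model supports embeddings; the field names are taken from the request model in this diff, while the address and model name are placeholders.

```python
import requests

# Placeholder address and model name.
api_url = "http://localhost:8000/v1/embeddings"

payload = {
    "model": "your-embedding-model",
    # Chat-style input accepted by EmbeddingChatRequest; the server renders it
    # through the model's chat template before computing the embedding.
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
    ],
    "encoding_format": "float",
    # Extra parameters defined on the request model:
    "add_generation_prompt": True,
    "add_special_tokens": False,
}

response = requests.post(api_url, json=payload)
response.raise_for_status()
print(len(response.json()["data"][0]["embedding"]))
```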
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/run_batch.py
@@ -223,6 +223,7 @@ async def main(args):
model_config,
base_model_paths,
request_logger=request_logger,
chat_template=None,
)

tracker = BatchProgressTracker()
22 changes: 9 additions & 13 deletions vllm/entrypoints/openai/serving_chat.py
@@ -38,7 +38,7 @@
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import iterate_with_cancellation
from vllm.utils import is_list_of, iterate_with_cancellation

logger = init_logger(__name__)

@@ -94,12 +94,12 @@ async def create_chat_completion(
raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
ErrorResponse]:
"""Completion API similar to OpenAI's API.
"""
Chat Completion API similar to OpenAI's API.

See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI
ChatCompletion API.

Chat Completion API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
@@ -152,14 +152,10 @@ async def create_chat_completion(
documents=request.documents,
**(request.chat_template_kwargs or {}),
)
except Exception as e:
logger.exception("Error in applying chat template from request")
return self.create_error_response(str(e))

try:
mm_data = await mm_data_future
except Exception as e:
logger.exception("Error in loading multi-modal data")
logger.exception("Error in applying chat template from request")
return self.create_error_response(str(e))

# validation for OpenAI tools
@@ -176,7 +172,7 @@
"\"auto\" tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set")

request_id = f"chat-{request.request_id}"
request_id = f"chatcmpl-{request.request_id}"

request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
@@ -196,9 +192,9 @@
add_special_tokens=request.add_special_tokens,
)
else:
assert isinstance(prompt, list) and isinstance(
prompt[0], int
), "Prompt has to be either a string or a list of token ids"
# For MistralTokenizer
assert is_list_of(prompt, int), (
"Prompt has to be either a string or a list of token ids")
prompt_inputs = TextTokensPrompt(
prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)

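The new assertion relies on `is_list_of` from `vllm.utils`. Conceptually it presumably behaves like the simplified sketch below (not the actual implementation):

```python
from typing import Any


def is_list_of(value: Any, typ: type) -> bool:
    """Return True if value is a list whose items are all instances of typ."""
    return isinstance(value, list) and all(isinstance(item, typ) for item in value)


# In serving_chat.py this guards the MistralTokenizer path, where the rendered
# prompt may already be a list of token ids rather than a string.
assert is_list_of([1, 2, 3], int)
assert not is_list_of("not token ids", int)
```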
8 changes: 1 addition & 7 deletions vllm/entrypoints/openai/serving_completion.py
@@ -1,7 +1,6 @@
import asyncio
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional)
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple, Union, cast

@@ -37,11 +36,6 @@

logger = init_logger(__name__)

TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]


class OpenAIServingCompletion(OpenAIServing):
