Skip to content

Commit

Permalink
feat: support complex message content for chat completions endpoint
Browse files Browse the repository at this point in the history
Co-authored-by: Lily Liu <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
  • Loading branch information
3 people committed Apr 29, 2024
1 parent df29793 commit 2398646
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 21 deletions.
19 changes: 19 additions & 0 deletions tests/entrypoints/test_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,25 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI):
assert "extra_forbidden" in exc_info.value.message


async def test_complex_message_content(server, client: openai.AsyncOpenAI):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role":
"user",
"content": [{
"type":
"text",
"text":
"what is 1+1? please provide the result without any other text."
}]
}],
temperature=0,
seed=0)
content = resp.choices[0].message.content
assert content == "2"


async def test_guided_grammar(server, client: openai.AsyncOpenAI):
simple_sql_grammar = """
start: select_statement
Expand Down
48 changes: 27 additions & 21 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,16 @@ def _parse_chat_message_content(
if isinstance(content, str):
return [ConversationMessage(role=role, content=content)], []

# To be implemented: https://github.com/vllm-project/vllm/pull/3467
# To be implemented: https://github.com/vllm-project/vllm/pull/4200
raise NotImplementedError("Complex input not supported yet")
texts: List[str] = []
for _, part in enumerate(content):
if part["type"] == "text":
text = part["text"]

texts.append(text)
else:
raise NotImplementedError(f"Unknown part type: {part['type']}")

return [ConversationMessage(role=role, content="\n".join(texts))], []

async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Request
Expand Down Expand Up @@ -122,11 +129,12 @@ async def create_chat_completion(
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id)
request, result_generator, request_id, conversation)
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id)
request, raw_request, result_generator, request_id,
conversation)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
Expand All @@ -139,8 +147,9 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:

async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str) -> AsyncGenerator[str, None]:
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
chunk_object_type = "chat.completion.chunk"
Expand Down Expand Up @@ -179,12 +188,10 @@ async def chat_completion_stream_generator(
# last message
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages,
list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
if conversation and conversation[-1].get(
"content") and conversation[-1].get(
"role") == role:
last_msg_content = conversation[-1]["content"]

if last_msg_content:
for i in range(request.n):
Expand Down Expand Up @@ -279,9 +286,10 @@ async def chat_completion_stream_generator(
yield "data: [DONE]\n\n"

async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request,
result_generator: AsyncIterator[RequestOutput],
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
self, request: ChatCompletionRequest, raw_request: Request,
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
) -> Union[ErrorResponse, ChatCompletionResponse]:

model_name = self.served_model_names[0]
created_time = int(time.time())
Expand Down Expand Up @@ -322,11 +330,9 @@ async def chat_completion_full_generator(

if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages, list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
if conversation and conversation[-1].get(
"content") and conversation[-1].get("role") == role:
last_msg_content = conversation[-1]["content"]

for choice in choices:
full_message = last_msg_content + choice.message.content
Expand Down

0 comments on commit 2398646

Please sign in to comment.