From 239864632af6a27e8dc7fae0b270881aaa501aed Mon Sep 17 00:00:00 2001
From: Florian Greinacher
Date: Mon, 18 Mar 2024 17:34:17 +0100
Subject: [PATCH] feat: support complex message content for chat completions
 endpoint

Co-authored-by: Lily Liu
Co-authored-by: Cyrus Leung
---
 tests/entrypoints/test_openai_server.py | 19 ++++++++++
 vllm/entrypoints/openai/serving_chat.py | 48 ++++++++++++++-----------
 2 files changed, 46 insertions(+), 21 deletions(-)

diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 68332228ace08..a2a98abe7031c 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -786,6 +786,25 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI):
     assert "extra_forbidden" in exc_info.value.message
 
 
+async def test_complex_message_content(server, client: openai.AsyncOpenAI):
+    resp = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[{
+            "role":
+            "user",
+            "content": [{
+                "type":
+                "text",
+                "text":
+                "what is 1+1? please provide the result without any other text."
+            }]
+        }],
+        temperature=0,
+        seed=0)
+    content = resp.choices[0].message.content
+    assert content == "2"
+
+
 async def test_guided_grammar(server, client: openai.AsyncOpenAI):
     simple_sql_grammar = """
     start: select_statement
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 5ed042ef386ea..599f99e56a726 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -55,9 +55,16 @@ def _parse_chat_message_content(
         if isinstance(content, str):
             return [ConversationMessage(role=role, content=content)], []
 
-        # To be implemented: https://github.com/vllm-project/vllm/pull/3467
-        # To be implemented: https://github.com/vllm-project/vllm/pull/4200
-        raise NotImplementedError("Complex input not supported yet")
+        texts: List[str] = []
+        for _, part in enumerate(content):
+            if part["type"] == "text":
+                text = part["text"]
+
+                texts.append(text)
+            else:
+                raise NotImplementedError(f"Unknown part type: {part['type']}")
+
+        return [ConversationMessage(role=role, content="\n".join(texts))], []
 
     async def create_chat_completion(
         self, request: ChatCompletionRequest, raw_request: Request
@@ -122,11 +129,12 @@ async def create_chat_completion(
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id)
+                request, result_generator, request_id, conversation)
         else:
             try:
                 return await self.chat_completion_full_generator(
-                    request, raw_request, result_generator, request_id)
+                    request, raw_request, result_generator, request_id,
+                    conversation)
             except ValueError as e:
                 # TODO: Use a vllm-specific Validation Error
                 return self.create_error_response(str(e))
@@ -139,8 +147,9 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
 
     async def chat_completion_stream_generator(
             self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput],
-            request_id: str) -> AsyncGenerator[str, None]:
+            result_generator: AsyncIterator[RequestOutput], request_id: str,
+            conversation: List[ConversationMessage]
+    ) -> AsyncGenerator[str, None]:
         model_name = self.served_model_names[0]
         created_time = int(time.time())
         chunk_object_type = "chat.completion.chunk"
@@ -179,12 +188,10 @@ async def chat_completion_stream_generator(
                     # last message
                     if request.echo:
                         last_msg_content = ""
-                        if request.messages and isinstance(
-                                request.messages,
-                                list) and request.messages[-1].get(
-                                    "content") and request.messages[-1].get(
-                                        "role") == role:
-                            last_msg_content = request.messages[-1]["content"]
+                        if conversation and conversation[-1].get(
+                                "content") and conversation[-1].get(
+                                    "role") == role:
+                            last_msg_content = conversation[-1]["content"]
 
                         if last_msg_content:
                             for i in range(request.n):
@@ -279,9 +286,10 @@ async def chat_completion_stream_generator(
             yield "data: [DONE]\n\n"
 
     async def chat_completion_full_generator(
-            self, request: ChatCompletionRequest, raw_request: Request,
-            result_generator: AsyncIterator[RequestOutput],
-            request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
+            self, request: ChatCompletionRequest, raw_request: Request,
+            result_generator: AsyncIterator[RequestOutput], request_id: str,
+            conversation: List[ConversationMessage]
+    ) -> Union[ErrorResponse, ChatCompletionResponse]:
 
         model_name = self.served_model_names[0]
         created_time = int(time.time())
@@ -322,11 +330,9 @@ async def chat_completion_full_generator(
 
         if request.echo:
             last_msg_content = ""
-            if request.messages and isinstance(
-                    request.messages, list) and request.messages[-1].get(
-                        "content") and request.messages[-1].get(
-                            "role") == role:
-                last_msg_content = request.messages[-1]["content"]
+            if conversation and conversation[-1].get(
+                    "content") and conversation[-1].get("role") == role:
+                last_msg_content = conversation[-1]["content"]
 
             for choice in choices:
                 full_message = last_msg_content + choice.message.content
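
The behavioral core of this patch is the new content-parsing rule in
_parse_chat_message_content: a plain string passes through unchanged, a list
of parts may contain only "text" parts, and multiple text parts are joined
with a newline. Below is a minimal standalone sketch of that rule; the
function name parse_content is illustrative, not the actual vLLM method.

    from typing import Dict, List, Union

    def parse_content(content: Union[str, List[Dict[str, str]]]) -> str:
        # Plain string content is passed through unchanged.
        if isinstance(content, str):
            return content
        # List-of-parts content: collect "text" parts, reject anything else.
        texts: List[str] = []
        for part in content:
            if part["type"] == "text":
                texts.append(part["text"])
            else:
                raise NotImplementedError(f"Unknown part type: {part['type']}")
        # Multiple text parts are concatenated with a newline separator.
        return "\n".join(texts)

    assert parse_content("hi") == "hi"
    assert parse_content([{"type": "text", "text": "a"},
                          {"type": "text", "text": "b"}]) == "a\nb"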
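On the client side, the new test exercises the list-of-parts form through the
OpenAI SDK. A hedged usage sketch against a locally running vLLM
OpenAI-compatible server follows; the base URL, API key, and model name are
placeholder assumptions, not part of the patch.

    import openai

    # Placeholder server address and credentials; adjust to your deployment.
    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    resp = client.chat.completions.create(
        model="meta-llama/Llama-2-7b-chat-hf",  # assumed model name
        messages=[{
            "role": "user",
            # "content" may now be a list of typed parts instead of a plain
            # string; only "text" parts are accepted by this patch.
            "content": [{
                "type": "text",
                "text": "what is 1+1? please provide the result "
                        "without any other text."
            }],
        }],
        temperature=0,
        seed=0,
    )
    print(resp.choices[0].message.content)  # the test expects "2"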