From ec1006514b6d814aa1079e55b61b31ba11ee952c Mon Sep 17 00:00:00 2001
From: Andy <37781802+aandyw@users.noreply.github.com>
Date: Tue, 24 Sep 2024 12:44:11 -0400
Subject: [PATCH] [Frontend] Batch inference for llm.chat() API (#8648)

Co-authored-by: Cyrus Leung
Co-authored-by: Cyrus Leung
Co-authored-by: Roger Wang
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Signed-off-by: Sumit Dubey
---
 examples/offline_inference_chat.py     | 27 +++++++++
 tests/entrypoints/llm/test_generate.py | 35 +++++++++++
 vllm/entrypoints/llm.py                | 82 +++++++++++++++-----------
 3 files changed, 111 insertions(+), 33 deletions(-)

diff --git a/examples/offline_inference_chat.py b/examples/offline_inference_chat.py
index c2020724c72fe..8814f4d7bef0d 100644
--- a/examples/offline_inference_chat.py
+++ b/examples/offline_inference_chat.py
@@ -39,6 +39,33 @@ def print_outputs(outputs):
                    use_tqdm=False)
 print_outputs(outputs)
 
+# You can run batch inference with llm.chat API
+conversation = [
+    {
+        "role": "system",
+        "content": "You are a helpful assistant"
+    },
+    {
+        "role": "user",
+        "content": "Hello"
+    },
+    {
+        "role": "assistant",
+        "content": "Hello! How can I assist you today?"
+    },
+    {
+        "role": "user",
+        "content": "Write an essay about the importance of higher education.",
+    },
+]
+conversations = [conversation for _ in range(10)]
+
+# We turn on tqdm progress bar to verify it's indeed running batch inference
+outputs = llm.chat(messages=conversations,
+                   sampling_params=sampling_params,
+                   use_tqdm=True)
+print_outputs(outputs)
+
 # A chat template can be optionally supplied.
 # If not, the model will use its default chat template.
 
diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py
index ef34bebbb0f8c..cd989225e2483 100644
--- a/tests/entrypoints/llm/test_generate.py
+++ b/tests/entrypoints/llm/test_generate.py
@@ -162,6 +162,41 @@ def test_chat():
     assert len(outputs) == 1
 
 
+def test_multi_chat():
+
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
+
+    prompt1 = "Explain the concept of entropy."
+    prompt2 = "Explain what among us is."
+
+    conversation1 = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": prompt1
+        },
+    ]
+
+    conversation2 = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": prompt2
+        },
+    ]
+
+    messages = [conversation1, conversation2]
+
+    outputs = llm.chat(messages)
+    assert len(outputs) == 2
+
+
 @pytest.mark.parametrize("image_urls",
                          [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
 def test_chat_multi_image(image_urls: List[str]):
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index ca80dedd29ebd..cd10eda8c212c 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -485,7 +485,8 @@ def beam_search(
 
     def chat(
         self,
-        messages: List[ChatCompletionMessageParam],
+        messages: Union[List[ChatCompletionMessageParam],
+                        List[List[ChatCompletionMessageParam]]],
         sampling_params: Optional[Union[SamplingParams,
                                         List[SamplingParams]]] = None,
         use_tqdm: bool = True,
@@ -505,8 +506,9 @@ def chat(
         to the OpenAI API.
 
         Args:
-            messages: A single conversation represented as a list of messages.
-                Each message is a dictionary with 'role' and 'content' keys.
+            messages: A list of conversations or a single conversation.
+                - Each conversation is represented as a list of messages.
+                - Each message is a dictionary with 'role' and 'content' keys.
             sampling_params: The sampling parameters for text generation.
                 If None, we use the default sampling parameters. When it
                 is a single value, it is applied to every prompt. When it
@@ -523,42 +525,56 @@ def chat(
             A list of ``RequestOutput`` objects containing the generated
             responses in the same order as the input messages.
         """
+        list_of_messages: List[List[ChatCompletionMessageParam]]
 
-        tokenizer = self.get_tokenizer()
-        model_config = self.llm_engine.get_model_config()
-
-        conversation, mm_data = parse_chat_messages(messages, model_config,
-                                                    tokenizer)
-
-        prompt: Union[str, List[int]]
-        if isinstance(tokenizer, MistralTokenizer):
-            prompt = apply_mistral_chat_template(
-                tokenizer,
-                messages=messages,
-                chat_template=chat_template,
-                add_generation_prompt=add_generation_prompt,
-                tools=tools,
-            )
+        # Handle multi and single conversations
+        if is_list_of(messages, list):
+            # messages is List[List[...]]
+            list_of_messages = messages
         else:
-            prompt = apply_hf_chat_template(
-                tokenizer,
-                conversation=conversation,
-                chat_template=chat_template,
-                add_generation_prompt=add_generation_prompt,
-                tools=tools,
-            )
+            # messages is List[...]
+            list_of_messages = [messages]
+
+        prompts: List[Union[TokensPrompt, TextPrompt]] = []
+
+        for msgs in list_of_messages:
+            tokenizer = self.get_tokenizer()
+            model_config = self.llm_engine.get_model_config()
+
+            conversation, mm_data = parse_chat_messages(
+                msgs, model_config, tokenizer)
+
+            prompt_data: Union[str, List[int]]
+            if isinstance(tokenizer, MistralTokenizer):
+                prompt_data = apply_mistral_chat_template(
+                    tokenizer,
+                    messages=msgs,
+                    chat_template=chat_template,
+                    add_generation_prompt=add_generation_prompt,
+                    tools=tools,
+                )
+            else:
+                prompt_data = apply_hf_chat_template(
+                    tokenizer,
+                    conversation=conversation,
+                    chat_template=chat_template,
+                    add_generation_prompt=add_generation_prompt,
+                    tools=tools,
+                )
+
+            prompt: Union[TokensPrompt, TextPrompt]
+            if is_list_of(prompt_data, int):
+                prompt = TokensPrompt(prompt_token_ids=prompt_data)
+            else:
+                prompt = TextPrompt(prompt=prompt_data)
 
-        inputs: PromptInputs
-        if is_list_of(prompt, int):
-            inputs = TokensPrompt(prompt_token_ids=prompt)
-        else:
-            inputs = TextPrompt(prompt=prompt)
+            if mm_data is not None:
+                prompt["multi_modal_data"] = mm_data
 
-        if mm_data is not None:
-            inputs["multi_modal_data"] = mm_data
+            prompts.append(prompt)
 
         return self.generate(
-            inputs,
+            prompts,
             sampling_params=sampling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
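
A minimal usage sketch of the batched llm.chat() call this patch enables (not part of the patch itself). It mirrors the new example in examples/offline_inference_chat.py; the model name is taken from the new test, and the sampling values and batch size are placeholder assumptions.

from vllm import LLM, SamplingParams

# Placeholder sampling settings for illustration only.
sampling_params = SamplingParams(temperature=0.5, max_tokens=256)

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Write an essay about the importance of higher education."},
]

# Passing a list of conversations dispatches them as one batched generate() call;
# a single conversation (a plain list of messages) still works as before.
outputs = llm.chat(
    messages=[conversation] * 8,
    sampling_params=sampling_params,
    use_tqdm=True,
)

for output in outputs:
    print(output.outputs[0].text)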