[Frontend] Batch inference for llm.chat() API (vllm-project#8648)
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
5 people authored and dtrifiro committed Sep 27, 2024
1 parent fcf6bf5 commit 69d29e6
Showing 3 changed files with 111 additions and 33 deletions.
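In short, after this change llm.chat() accepts either a single conversation or a list of conversations and returns one RequestOutput per conversation. A minimal sketch of the batched call (not part of the commit; the model name and prompts mirror the test added below, and the sampling values are illustrative):

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
sampling_params = SamplingParams(temperature=0.8, max_tokens=256)

conversation1 = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Explain the concept of entropy."},
]
conversation2 = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Explain what among us is."},
]

# One RequestOutput per conversation, returned in the same order as the inputs.
outputs = llm.chat([conversation1, conversation2],
                   sampling_params=sampling_params,
                   use_tqdm=True)
for output in outputs:
    print(output.outputs[0].text)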
27 changes: 27 additions & 0 deletions examples/offline_inference_chat.py
@@ -39,6 +39,33 @@ def print_outputs(outputs):
                    use_tqdm=False)
 print_outputs(outputs)
 
+# You can run batch inference with llm.chat API
+conversation = [
+    {
+        "role": "system",
+        "content": "You are a helpful assistant"
+    },
+    {
+        "role": "user",
+        "content": "Hello"
+    },
+    {
+        "role": "assistant",
+        "content": "Hello! How can I assist you today?"
+    },
+    {
+        "role": "user",
+        "content": "Write an essay about the importance of higher education.",
+    },
+]
+conversations = [conversation for _ in range(10)]
+
+# We turn on tqdm progress bar to verify it's indeed running batch inference
+outputs = llm.chat(messages=conversations,
+                   sampling_params=sampling_params,
+                   use_tqdm=True)
+print_outputs(outputs)
+
 # A chat template can be optionally supplied.
 # If not, the model will use its default chat template.
 
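The context comment above ("A chat template can be optionally supplied") refers to the chat_template argument of llm.chat(), which is applied per conversation in the batched path as well. A minimal sketch of supplying one (not part of the commit; the Jinja template string is purely illustrative, and llm, conversations, sampling_params and print_outputs are the variables defined in the example above):

# Purely illustrative Jinja template; real templates normally ship with the model.
chat_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "assistant: "
)

outputs = llm.chat(messages=conversations,
                   sampling_params=sampling_params,
                   chat_template=chat_template,
                   use_tqdm=True)
print_outputs(outputs)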
35 changes: 35 additions & 0 deletions tests/entrypoints/llm/test_generate.py
@@ -162,6 +162,41 @@ def test_chat():
     assert len(outputs) == 1
 
 
+def test_multi_chat():
+
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
+
+    prompt1 = "Explain the concept of entropy."
+    prompt2 = "Explain what among us is."
+
+    conversation1 = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": prompt1
+        },
+    ]
+
+    conversation2 = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": prompt2
+        },
+    ]
+
+    messages = [conversation1, conversation2]
+
+    outputs = llm.chat(messages)
+    assert len(outputs) == 2
+
+
 @pytest.mark.parametrize("image_urls",
                          [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
 def test_chat_multi_image(image_urls: List[str]):
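For comparison with test_multi_chat, the pre-existing test_chat (whose assert len(outputs) == 1 line appears as context above) still exercises the single-conversation path: a plain list of messages is wrapped into a one-element batch internally. A minimal sketch of that call (not part of the commit; llm is assumed to be the same LLM instance as in the tests):

# A single conversation (List[dict]) yields exactly one RequestOutput.
conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
]
outputs = llm.chat(conversation)
assert len(outputs) == 1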
82 changes: 49 additions & 33 deletions vllm/entrypoints/llm.py
@@ -485,7 +485,8 @@ def beam_search(
 
     def chat(
         self,
-        messages: List[ChatCompletionMessageParam],
+        messages: Union[List[ChatCompletionMessageParam],
+                        List[List[ChatCompletionMessageParam]]],
         sampling_params: Optional[Union[SamplingParams,
                                         List[SamplingParams]]] = None,
         use_tqdm: bool = True,
@@ -505,8 +506,9 @@ def chat(
             to the OpenAI API.
 
         Args:
-            messages: A single conversation represented as a list of messages.
-                Each message is a dictionary with 'role' and 'content' keys.
+            messages: A list of conversations or a single conversation.
+                - Each conversation is represented as a list of messages.
+                - Each message is a dictionary with 'role' and 'content' keys.
             sampling_params: The sampling parameters for text generation.
                 If None, we use the default sampling parameters. When it
                 is a single value, it is applied to every prompt. When it
@@ -523,42 +525,56 @@ def chat(
             A list of ``RequestOutput`` objects containing the generated
             responses in the same order as the input messages.
         """
+        list_of_messages: List[List[ChatCompletionMessageParam]]
 
-        tokenizer = self.get_tokenizer()
-        model_config = self.llm_engine.get_model_config()
-
-        conversation, mm_data = parse_chat_messages(messages, model_config,
-                                                    tokenizer)
-
-        prompt: Union[str, List[int]]
-        if isinstance(tokenizer, MistralTokenizer):
-            prompt = apply_mistral_chat_template(
-                tokenizer,
-                messages=messages,
-                chat_template=chat_template,
-                add_generation_prompt=add_generation_prompt,
-                tools=tools,
-            )
+        # Handle multi and single conversations
+        if is_list_of(messages, list):
+            # messages is List[List[...]]
+            list_of_messages = messages
         else:
-            prompt = apply_hf_chat_template(
-                tokenizer,
-                conversation=conversation,
-                chat_template=chat_template,
-                add_generation_prompt=add_generation_prompt,
-                tools=tools,
-            )
+            # messages is List[...]
+            list_of_messages = [messages]
+
+        prompts: List[Union[TokensPrompt, TextPrompt]] = []
+
+        for msgs in list_of_messages:
+            tokenizer = self.get_tokenizer()
+            model_config = self.llm_engine.get_model_config()
+
+            conversation, mm_data = parse_chat_messages(
+                msgs, model_config, tokenizer)
+
+            prompt_data: Union[str, List[int]]
+            if isinstance(tokenizer, MistralTokenizer):
+                prompt_data = apply_mistral_chat_template(
+                    tokenizer,
+                    messages=msgs,
+                    chat_template=chat_template,
+                    add_generation_prompt=add_generation_prompt,
+                    tools=tools,
+                )
+            else:
+                prompt_data = apply_hf_chat_template(
+                    tokenizer,
+                    conversation=conversation,
+                    chat_template=chat_template,
+                    add_generation_prompt=add_generation_prompt,
+                    tools=tools,
+                )
+
+            prompt: Union[TokensPrompt, TextPrompt]
+            if is_list_of(prompt_data, int):
+                prompt = TokensPrompt(prompt_token_ids=prompt_data)
+            else:
+                prompt = TextPrompt(prompt=prompt_data)
 
-        inputs: PromptInputs
-        if is_list_of(prompt, int):
-            inputs = TokensPrompt(prompt_token_ids=prompt)
-        else:
-            inputs = TextPrompt(prompt=prompt)
+            if mm_data is not None:
+                prompt["multi_modal_data"] = mm_data
 
-        if mm_data is not None:
-            inputs["multi_modal_data"] = mm_data
+            prompts.append(prompt)
 
         return self.generate(
-            inputs,
+            prompts,
             sampling_params=sampling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
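Since chat() now builds a list of prompts and forwards it to generate(), the sampling_params argument shown in the signature above can be a single SamplingParams applied to every conversation or a list expected to match the number of conversations, and the returned RequestOutput objects keep the input order. A minimal sketch (not part of the commit; the parameter values are illustrative and conversation1/conversation2 are assumed to be message lists like those in the test above):

from vllm import SamplingParams

# One SamplingParams per conversation; outputs come back in the same order.
per_conversation_params = [
    SamplingParams(temperature=0.0, max_tokens=128),
    SamplingParams(temperature=0.9, max_tokens=512),
]

outputs = llm.chat([conversation1, conversation2],
                   sampling_params=per_conversation_params)
assert len(outputs) == 2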
