diff --git a/docs/source/getting_started/quickstart.rst b/docs/source/getting_started/quickstart.rst
index 3083538e1dcff..bd3940f650e06 100644
--- a/docs/source/getting_started/quickstart.rst
+++ b/docs/source/getting_started/quickstart.rst
@@ -107,6 +107,7 @@ OpenAI-Compatible Server
 ------------------------
 vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
+By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models `_, `create chat completion `_, and `create completion `_ endpoints. We are actively adding support for more endpoints.
 Start the server:
@@ -122,7 +123,13 @@ Use model from www.modelscope.cn
     $ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.openai.api_server \
     $ --model="qwen/Qwen-7B-Chat" --revision="v1.1.8" --trust-remote-code
-By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models `_ and `create completion `_ endpoints. We are actively adding support for more endpoints.
+By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument:
+
+.. code-block:: console
+
+    $ python -m vllm.entrypoints.openai.api_server \
+    $ --model facebook/opt-125m \
+    $ --chat-template ./examples/template_chatml.jinja
 This server can be queried in the same format as OpenAI API. For example, list the models:
@@ -130,6 +137,9 @@ This server can be queried in the same format as OpenAI API. For example, list t
     $ curl http://localhost:8000/v1/models
+Using OpenAI Completions API with vLLM
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
 Query the model with input prompts:
 .. code-block:: console
@@ -156,3 +166,45 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
     print("Completion result:", completion)
 For a more detailed client example, refer to `examples/openai_completion_client.py `_.
+
+Using OpenAI Chat API with vLLM
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The vLLM server supports the OpenAI Chat API, which lets you hold multi-turn, back-and-forth conversations with the model by sending the running chat history with each request. This is useful for tasks that require context or more detailed explanations.
+
+Querying the model using the OpenAI Chat API:
+
+You can use the `create chat completion `_ endpoint to communicate with the model in a chat-like interface:
+
+.. code-block:: console
+
+    $ curl http://localhost:8000/v1/chat/completions \
+    $ -H "Content-Type: application/json" \
+    $ -d '{
+    $ "model": "facebook/opt-125m",
+    $ "messages": [
+    $ {"role": "system", "content": "You are a helpful assistant."},
+    $ {"role": "user", "content": "Who won the world series in 2020?"}
+    $ ]
+    $ }'
+
+Python Client Example:
+
+Using the ``openai`` Python package, you can also communicate with the model in a chat-like manner:
+
+.. code-block:: python
+
+    import openai
+    # Set OpenAI's API key and API base to use vLLM's API server.
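+    # The server does not validate the key, so any placeholder works; only the base URL needs to point at the running vLLM server.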
+ openai.api_key = "EMPTY" + openai.api_base = "http://localhost:8000/v1" + chat_response = openai.ChatCompletion.create( + model="facebook/opt-125m", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me a joke."}, + ] + ) + print("Chat response:", chat_response) + +For more in-depth examples and advanced features of the chat API, you can refer to the official OpenAI documentation. diff --git a/examples/template_alpaca.jinja b/examples/template_alpaca.jinja new file mode 100644 index 0000000000000..60667acc3ef96 --- /dev/null +++ b/examples/template_alpaca.jinja @@ -0,0 +1,29 @@ +{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} + +{% for message in messages %} +{% if message['role'] == 'user' %} +### Instruction: +{{ message['content']|trim -}} +{% if not loop.last %} + + +{% endif %} +{% elif message['role'] == 'assistant' %} +### Response: +{{ message['content']|trim -}} +{% if not loop.last %} + + +{% endif %} +{% elif message['role'] == 'user_context' %} +### Input: +{{ message['content']|trim -}} +{% if not loop.last %} + + +{% endif %} +{% endif %} +{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %} +### Response: +{% endif %} \ No newline at end of file diff --git a/examples/template_chatml.jinja b/examples/template_chatml.jinja new file mode 100644 index 0000000000000..4844e681e1b6c --- /dev/null +++ b/examples/template_chatml.jinja @@ -0,0 +1,2 @@ +{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %} \ No newline at end of file diff --git a/examples/template_inkbot.jinja b/examples/template_inkbot.jinja new file mode 100644 index 0000000000000..33a817454df36 --- /dev/null +++ b/examples/template_inkbot.jinja @@ -0,0 +1,30 @@ +<#meta#> +- Date: {{ (messages|selectattr('role', 'equalto', 'meta-current_date')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'meta-current_date')|list) else '' }} +- Task: {{ (messages|selectattr('role', 'equalto', 'meta-task_name')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'meta-task_name')|list) else '' }} +<#system#> +{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} +<#chat#> +{% for message in messages %} +{% if message['role'] == 'user' %} +<#user#> +{{ message['content']|trim -}} +{% if not loop.last %} + +{% endif %} +{% elif message['role'] == 'assistant' %} +<#bot#> +{{ message['content']|trim -}} +{% if not loop.last %} + +{% endif %} +{% elif message['role'] == 'user_context' %} +<#user_context#> +{{ message['content']|trim -}} +{% if not loop.last %} + +{% endif %} +{% endif %} +{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %} +<#bot#> +{% endif %} \ No newline at end of file diff --git a/tests/async_engine/test_openai_server.py b/tests/async_engine/test_openai_server.py new file mode 100644 index 0000000000000..a61ff7e84ca66 --- /dev/null +++ b/tests/async_engine/test_openai_server.py @@ -0,0 +1,119 @@ +from argparse import Namespace +from dataclasses import dataclass + +import pytest +from fastapi.testclient 
import TestClient + +from vllm.entrypoints.openai.api_server import * + +# Define models, templates, and their corresponding expected outputs +MODEL_TEMPLATE_GENERATON_OUTPUT = [ + ("facebook/opt-125m", None, True, + "HelloHi there!What is the capital of"), + ("facebook/opt-125m", None, False, + "HelloHi there!What is the capital of"), + ("facebook/opt-125m", "../../examples/template_chatml.jinja", True, + """<|im_start|>user +Hello<|im_end|> +<|im_start|>assistant +Hi there!<|im_end|> +<|im_start|>user +What is the capital of<|im_end|> +<|im_start|>assistant +"""), + ("facebook/opt-125m", "../../examples/template_chatml.jinja", False, + """<|im_start|>user +Hello<|im_end|> +<|im_start|>assistant +Hi there!<|im_end|> +<|im_start|>user +What is the capital of""") +] + +TEST_MESSAGES = [ + { + 'role': 'user', + 'content': 'Hello' + }, + { + 'role': 'assistant', + 'content': 'Hi there!' + }, + { + 'role': 'user', + 'content': 'What is the capital of' + }, +] +client = TestClient(app) + + +@dataclass +class MockTokenizer: + chat_template = None + + +def test_load_chat_template(): + # Testing chatml template + template = "../../examples/template_chatml.jinja" + mock_args = Namespace(chat_template=template) + tokenizer = MockTokenizer() + + # Call the function with the mocked args + load_chat_template(mock_args, tokenizer) + + template_content = tokenizer.chat_template + + # Test assertions + assert template_content is not None + # Hard coded value for template_chatml.jinja + assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" + + +def test_no_load_chat_template(): + # Testing chatml template + template = "../../examples/does_not_exist" + mock_args = Namespace(chat_template=template) + tokenizer = MockTokenizer() + + # Call the function with the mocked args + load_chat_template(mock_args, tokenizer=tokenizer) + template_content = tokenizer.chat_template + + # Test assertions + assert template_content is not None + # Hard coded value for template_chatml.jinja + assert template_content == """../../examples/does_not_exist""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model,template,add_generation_prompt,expected_output", + MODEL_TEMPLATE_GENERATON_OUTPUT) +async def test_get_gen_prompt(model, template, add_generation_prompt, + expected_output): + # Initialize the tokenizer + tokenizer = get_tokenizer(tokenizer_name=model) + + mock_args = Namespace(chat_template=template) + load_chat_template(mock_args, tokenizer) + + # Create a mock request object using keyword arguments + mock_request = ChatCompletionRequest( + model=model, + messages=TEST_MESSAGES, + add_generation_prompt=add_generation_prompt) + + # Call the function and get the result + result = tokenizer.apply_chat_template( + conversation=mock_request.messages, + tokenize=False, + add_generation_prompt=mock_request.add_generation_prompt) + + # Test assertion + assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}" + + +def test_health_endpoint(): + response = client.get("/health") + assert response.status_code == 200 diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 3a15e5d352c60..ef9a398585a7c 100644 --- 
a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -3,6 +3,7 @@ import argparse import asyncio +import codecs import json import time from http import HTTPStatus @@ -14,7 +15,6 @@ from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse, Response -from packaging import version from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -31,20 +31,55 @@ from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import random_uuid -try: - import fastchat - from fastchat.conversation import Conversation, SeparatorStyle - from fastchat.model.model_adapter import get_conversation_template - _fastchat_available = True -except ImportError: - _fastchat_available = False - TIMEOUT_KEEP_ALIVE = 5 # seconds logger = init_logger(__name__) served_model = None app = fastapi.FastAPI() engine = None +response_role = None + + +def parse_args(): + parser = argparse.ArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server.") + parser.add_argument("--host", type=str, default=None, help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument("--allow-credentials", + action="store_true", + help="allow credentials") + parser.add_argument("--allowed-origins", + type=json.loads, + default=["*"], + help="allowed origins") + parser.add_argument("--allowed-methods", + type=json.loads, + default=["*"], + help="allowed methods") + parser.add_argument("--allowed-headers", + type=json.loads, + default=["*"], + help="allowed headers") + parser.add_argument("--served-model-name", + type=str, + default=None, + help="The model name used in the API. If not " + "specified, the model name will be the same as " + "the huggingface name.") + parser.add_argument("--chat-template", + type=str, + default=None, + help="The file path to the chat template, " + "or the template in single-line form " + "for the specified model") + parser.add_argument("--response-role", + type=str, + default="assistant", + help="The role name to return if " + "`request.add_generation_prompt=true`.") + + parser = AsyncEngineArgs.add_cli_args(parser) + return parser.parse_args() def create_error_response(status_code: HTTPStatus, @@ -54,6 +89,25 @@ def create_error_response(status_code: HTTPStatus, status_code=status_code.value) +def load_chat_template(args, tokenizer): + if args.chat_template is not None: + try: + with open(args.chat_template, "r") as f: + chat_template = f.read() + except OSError: + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + chat_template = codecs.decode(args.chat_template, "unicode_escape") + + tokenizer.chat_template = chat_template + logger.info( + f"Using supplied chat template:\n{tokenizer.chat_template}") + elif tokenizer.chat_template is not None: + logger.info(f"Using default chat template:\n{tokenizer.chat_template}") + else: + logger.warning("No chat template provided. Chat API will not work.") + + @app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): return create_error_response(HTTPStatus.BAD_REQUEST, str(exc)) @@ -69,53 +123,6 @@ async def check_model(request) -> Optional[JSONResponse]: return ret -async def get_gen_prompt(request) -> str: - if not _fastchat_available: - raise ModuleNotFoundError( - "fastchat is not installed. 
Please install fastchat to use " - "the chat completion and conversation APIs: `$ pip install fschat`" - ) - if version.parse(fastchat.__version__) < version.parse("0.2.23"): - raise ImportError( - f"fastchat version is low. Current version: {fastchat.__version__} " - "Please upgrade fastchat to use: `$ pip install -U fschat`") - - conv = get_conversation_template(request.model) - conv = Conversation( - name=conv.name, - system_template=conv.system_template, - system_message=conv.system_message, - roles=conv.roles, - messages=list(conv.messages), # prevent in-place modification - offset=conv.offset, - sep_style=SeparatorStyle(conv.sep_style), - sep=conv.sep, - sep2=conv.sep2, - stop_str=conv.stop_str, - stop_token_ids=conv.stop_token_ids, - ) - - if isinstance(request.messages, str): - prompt = request.messages - else: - for message in request.messages: - msg_role = message["role"] - if msg_role == "system": - conv.system_message = message["content"] - elif msg_role == "user": - conv.append_message(conv.roles[0], message["content"]) - elif msg_role == "assistant": - conv.append_message(conv.roles[1], message["content"]) - else: - raise ValueError(f"Unknown role: {msg_role}") - - # Add a blank message for the assistant. - conv.append_message(conv.roles[1], None) - prompt = conv.get_prompt() - - return prompt - - async def check_length( request: Union[ChatCompletionRequest, CompletionRequest], prompt: Optional[str] = None, @@ -207,7 +214,6 @@ async def create_chat_completion(request: ChatCompletionRequest, - function_call (Users should implement this by themselves) - logit_bias (to be supported by vLLM engine) """ - error_check_ret = await check_model(request) if error_check_ret is not None: return error_check_ret @@ -217,7 +223,15 @@ async def create_chat_completion(request: ChatCompletionRequest, return create_error_response(HTTPStatus.BAD_REQUEST, "logit_bias is not currently supported") - prompt = await get_gen_prompt(request) + try: + prompt = tokenizer.apply_chat_template( + conversation=request.messages, + tokenize=False, + add_generation_prompt=request.add_generation_prompt) + except Exception as e: + logger.error(f"Error in applying chat template from request: {str(e)}") + return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + token_ids, error_check_ret = await check_length(request, prompt=prompt) if error_check_ret is not None: return error_check_ret @@ -225,6 +239,7 @@ async def create_chat_completion(request: ChatCompletionRequest, model_name = request.model request_id = f"cmpl-{random_uuid()}" created_time = int(time.monotonic()) + chunk_object_type = "chat.completion.chunk" try: spaces_between_special_tokens = request.spaces_between_special_tokens sampling_params = SamplingParams( @@ -249,128 +264,162 @@ async def create_chat_completion(request: ChatCompletionRequest, result_generator = engine.generate(prompt, sampling_params, request_id, token_ids) - def create_stream_response_json( - index: int, - text: str, - finish_reason: Optional[str] = None, - usage: Optional[UsageInfo] = None, - ) -> str: - choice_data = ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage(content=text), - finish_reason=finish_reason, - ) - response = ChatCompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=[choice_data], - ) - if usage is not None: - response.usage = usage - # exclude unset to leave details out of each sse - response_json = response.json(exclude_unset=True, ensure_ascii=False) - - return response_json + def 
get_role() -> str: + if request.add_generation_prompt: + return response_role + else: + return request.messages[-1]["role"] async def completion_stream_generator() -> AsyncGenerator[str, None]: - # First chunk with role + # Send first response for each request.n (index) with the role + role = get_role() for i in range(request.n): choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(role="assistant"), - finish_reason=None, - ) + index=i, delta=DeltaMessage(role=role), finish_reason=None) chunk = ChatCompletionStreamResponse(id=request_id, + object=chunk_object_type, + created=created_time, choices=[choice_data], model=model_name) data = chunk.json(exclude_unset=True, ensure_ascii=False) yield f"data: {data}\n\n" + # Send response to echo the input portion of the last message + if request.echo: + last_msg_content = "" + if request.messages and isinstance( + request.messages, list) and request.messages[-1].get( + "content") and request.messages[-1].get( + "role") == role: + last_msg_content = request.messages[-1]["content"] + if last_msg_content: + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=last_msg_content), + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.json(exclude_unset=True, ensure_ascii=False) + yield f"data: {data}\n\n" + + # Send response for each token for each request.n (index) previous_texts = [""] * request.n previous_num_tokens = [0] * request.n + finish_reason_sent = [False] * request.n async for res in result_generator: res: RequestOutput for output in res.outputs: i = output.index - delta_text = output.text[len(previous_texts[i]):] - previous_texts[i] = output.text - completion_tokens = len(output.token_ids) - previous_num_tokens[i] = completion_tokens - response_json = create_stream_response_json( - index=i, - text=delta_text, - ) - yield f"data: {response_json}\n\n" - if output.finish_reason is not None: + + if finish_reason_sent[i]: + continue + + if output.finish_reason is None: + # Send token-by-token response for each request.n + delta_text = output.text[len(previous_texts[i]):] + previous_texts[i] = output.text + completion_tokens = len(output.token_ids) + previous_num_tokens[i] = completion_tokens + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.json(exclude_unset=True, ensure_ascii=False) + yield f"data: {data}\n\n" + else: + # Send the finish response for each request.n only once prompt_tokens = len(res.prompt_token_ids) final_usage = UsageInfo( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ) - response_json = create_stream_response_json( - index=i, - text="", - finish_reason=output.finish_reason, - usage=final_usage, - ) - yield f"data: {response_json}\n\n" + choice_data = ChatCompletionResponseStreamChoice( + index=i, delta=[], finish_reason=output.finish_reason) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if final_usage is not None: + chunk.usage = final_usage + data = chunk.json(exclude_unset=True, + 
exclude_none=True, + ensure_ascii=False) + yield f"data: {data}\n\n" + finish_reason_sent[i] = True + # Send the final done message after all response.n are finished yield "data: [DONE]\n\n" - # Streaming response - if request.stream: - return StreamingResponse(completion_stream_generator(), - media_type="text/event-stream") - - # Non-streaming response - final_res: RequestOutput = None - async for res in result_generator: - if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. - await engine.abort(request_id) - return create_error_response(HTTPStatus.BAD_REQUEST, - "Client disconnected") - final_res = res - assert final_res is not None - choices = [] - for output in final_res.outputs: - choice_data = ChatCompletionResponseChoice( - index=output.index, - message=ChatMessage(role="assistant", content=output.text), - finish_reason=output.finish_reason, + async def completion_full_generator(): + final_res: RequestOutput = None + async for res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + await engine.abort(request_id) + return create_error_response(HTTPStatus.BAD_REQUEST, + "Client disconnected") + final_res = res + assert final_res is not None + + choices = [] + role = get_role() + for output in final_res.outputs: + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=ChatMessage(role=role, content=output.text), + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + if request.echo: + last_msg_content = "" + if request.messages and isinstance( + request.messages, list) and request.messages[-1].get( + "content") and request.messages[-1].get( + "role") == role: + last_msg_content = request.messages[-1]["content"] + + for choice in choices: + full_message = last_msg_content + choice.message.content + choice.message.content = full_message + + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, ) - choices.append(choice_data) - num_prompt_tokens = len(final_res.prompt_token_ids) - num_generated_tokens = sum( - len(output.token_ids) for output in final_res.outputs) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + num_generated_tokens, - ) - response = ChatCompletionResponse( - id=request_id, - created=created_time, - model=model_name, - choices=choices, - usage=usage, - ) + return response + # Streaming response if request.stream: - # When user requests streaming but we don't stream, we still need to - # return a streaming response with a single event. 
- response_json = response.json(ensure_ascii=False) - - async def fake_stream_generator() -> AsyncGenerator[str, None]: - yield f"data: {response_json}\n\n" - yield "data: [DONE]\n\n" - - return StreamingResponse(fake_stream_generator(), + return StreamingResponse(completion_stream_generator(), media_type="text/event-stream") - - return response + else: + return await completion_full_generator() @app.post("/v1/completions") @@ -642,34 +691,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="vLLM OpenAI-Compatible RESTful API server.") - parser.add_argument("--host", type=str, default=None, help="host name") - parser.add_argument("--port", type=int, default=8000, help="port number") - parser.add_argument("--allow-credentials", - action="store_true", - help="allow credentials") - parser.add_argument("--allowed-origins", - type=json.loads, - default=["*"], - help="allowed origins") - parser.add_argument("--allowed-methods", - type=json.loads, - default=["*"], - help="allowed methods") - parser.add_argument("--allowed-headers", - type=json.loads, - default=["*"], - help="allowed headers") - parser.add_argument("--served-model-name", - type=str, - default=None, - help="The model name used in the API. If not " - "specified, the model name will be the same as " - "the huggingface name.") - - parser = AsyncEngineArgs.add_cli_args(parser) - args = parser.parse_args() + args = parse_args() app.add_middleware( CORSMiddleware, @@ -686,6 +708,8 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: else: served_model = args.model + response_role = args.response_role + engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args(engine_args) engine_model_config = asyncio.run(engine.get_model_config()) @@ -696,6 +720,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: engine_model_config.tokenizer, tokenizer_mode=engine_model_config.tokenizer_mode, trust_remote_code=engine_model_config.trust_remote_code) + load_chat_template(args, tokenizer) uvicorn.run(app, host=args.host, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 797f0a7115e6e..2aa567cb87034 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -73,6 +73,8 @@ class ChatCompletionRequest(BaseModel): stop_token_ids: Optional[List[int]] = Field(default_factory=list) skip_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True + add_generation_prompt: Optional[bool] = True + echo: Optional[bool] = False class CompletionRequest(BaseModel):
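
The two request fields added to ChatCompletionRequest above can be tried end-to-end once the server is running. Below is a minimal sketch, assuming a local server started as in the quickstart (default port 8000, model facebook/opt-125m) and using the requests package purely for illustration; the endpoint path and field names come from this change, everything else is incidental.

import requests

# Assumes: python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "facebook/opt-125m",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Tell me a joke."},
        ],
        # add_generation_prompt=True makes the server append the chat
        # template's generation prompt (e.g. "<|im_start|>assistant\n" for
        # chatml) before generating, and reply with the --response-role role.
        "add_generation_prompt": True,
        # echo=True would prepend the last message's content to the reply,
        # but only when that message has the same role as the reply.
        "echo": False,
    },
)
print(resp.json()["choices"][0]["message"]["content"])

When add_generation_prompt is false, get_role() instead returns the role of the last message in the conversation, so the reply is attributed to that role rather than to a new assistant turn.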