diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index 8ec9b05b2c521..e447469e33410 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -19,7 +19,7 @@ class ServerConfig(TypedDict): CONFIGS: Dict[str, ServerConfig] = { "hermes": { "model": - "NousResearch/Hermes-2-Pro-Llama-3-8B", + "NousResearch/Hermes-3-Llama-3.1-8B", "arguments": [ "--tool-call-parser", "hermes", "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 970262a4bd358..374196044b7e8 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -713,13 +713,6 @@ class DeltaToolCall(OpenAIBaseModel): function: Optional[DeltaFunctionCall] = None -# the initial delta that gets sent once a new tool call is started; -class InitialDeltaToolCall(DeltaToolCall): - id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") - type: Literal["function"] = "function" - index: int - - class ExtractedToolCallInformation(BaseModel): # indicate if tools were called tools_called: bool diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 78f355228012f..8ed81e9c88cb2 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -271,9 +271,13 @@ async def chat_completion_stream_generator( # NOTE num_choices defaults to 1 so this usually executes # once per request for i in range(num_choices): + choice_data = ChatCompletionResponseStreamChoice( index=i, - delta=DeltaMessage(role=role), + delta=DeltaMessage( + role=role, + content="", + ), logprobs=None, finish_reason=None) chunk = ChatCompletionStreamResponse( diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index b0807e6f1e782..873f615d43257 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -20,7 +20,6 @@ def __init__(self, tokenizer: AnyTokenizer): # the index of the tool call that is currently being parsed self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False - self.current_tool_initial_sent: bool = False self.streamed_args_for_tool: List[str] = [] self.model_tokenizer = tokenizer diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 7afbca7162edf..bde9b47ce60d5 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -8,14 +8,14 @@ from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, - FunctionCall, - InitialDeltaToolCall, ToolCall) + FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser) from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import random_uuid logger = init_logger(__name__) @@ -34,7 +34,6 @@ def __init__(self, tokenizer: AnyTokenizer): self.prev_tool_call_arr: List[Dict] = [] self.current_tool_id: int = -1 self.current_tool_name_sent = False - self.current_tool_initial_sent: bool = False self.streamed_args_for_tool: List[str] = [ ] # map what has been streamed for each tool so far to a list @@ -168,7 +167,6 @@ def extract_tool_calls_streaming( # set cursors and state appropriately self.current_tool_id += 1 self.current_tool_name_sent = False - self.current_tool_initial_sent = False self.streamed_args_for_tool.append("") logger.debug("Starting on a new tool %s", self.current_tool_id) @@ -218,24 +216,16 @@ def extract_tool_calls_streaming( logger.debug('not enough tokens to parse into JSON yet') return None - # case - we haven't sent the initial delta with the tool call ID - # (it will be sent) - if not self.current_tool_initial_sent: - self.current_tool_initial_sent = True - return DeltaMessage(tool_calls=[ - InitialDeltaToolCall( - index=self.current_tool_id).model_dump( - exclude_none=True) - ]) - # case - we haven't sent the tool name yet. If it's available, send # it. otherwise, wait until it's available. - elif not self.current_tool_name_sent: + if not self.current_tool_name_sent: function_name: Union[str, None] = current_tool_call.get("name") if function_name: self.current_tool_name_sent = True return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index d48770c792e98..4b0e1c91df97c 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -8,14 +8,14 @@ from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, - FunctionCall, - InitialDeltaToolCall, ToolCall) + FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser) from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import random_uuid logger = init_logger(__name__) @@ -25,7 +25,7 @@ class MistralToolParser(ToolParser): Tool call parser for Mistral 7B Instruct v0.3, intended for use with the examples/tool_chat_template_mistral.jinja template. - Used when --enable-auto-tool-choice --tool-call-parser gmistral are all set + Used when --enable-auto-tool-choice --tool-call-parser mistral are all set """ def __init__(self, tokenizer: AnyTokenizer): @@ -42,7 +42,6 @@ def __init__(self, tokenizer: AnyTokenizer): self.prev_tool_call_arr: List[Dict] = [] self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False - self.current_tool_initial_sent: bool = False self.streamed_args_for_tool: List[str] = [ ] # map what has been streamed for each tool so far to a list self.bot_token = "[TOOL_CALLS]" @@ -91,7 +90,6 @@ def extract_tool_calls(self, except Exception as e: logger.error("Error in extracting tool call from response: %s", e) - print("ERROR", e) # return information to just treat the tool call as regular JSON return ExtractedToolCallInformation(tools_called=False, tool_calls=[], @@ -109,7 +107,7 @@ def extract_tool_calls_streaming( # if the tool call token is not in the tokens generated so far, append # output to contents since it's not a tool - if self.bot_token_id not in current_token_ids: + if self.bot_token not in current_text: return DeltaMessage(content=delta_text) # if the tool call token ID IS in the tokens generated so far, that @@ -134,7 +132,7 @@ def extract_tool_calls_streaming( # replace BOT token with empty string, and convert single quotes # to double to allow parsing as JSON since mistral uses single # quotes instead of double for tool calls - parsable_arr = current_text.split(self.bot_token)[1] + parsable_arr = current_text.split(self.bot_token)[-1] # tool calls are generated in an array, so do partial JSON # parsing on the entire array @@ -186,31 +184,22 @@ def extract_tool_calls_streaming( # re-set stuff pertaining to progress in the current tool self.current_tool_id = len(tool_call_arr) - 1 self.current_tool_name_sent = False - self.current_tool_initial_sent = False self.streamed_args_for_tool.append("") logger.debug("starting on new tool %d", self.current_tool_id) return delta # case: update an existing tool - this is handled below - # if the current tool initial data incl. the id, type=function - # and idx not sent, send that - if not self.current_tool_initial_sent: - self.current_tool_initial_sent = True - delta = DeltaMessage(tool_calls=[ - InitialDeltaToolCall( - index=self.current_tool_id).model_dump( - exclude_none=True) - ]) - # if the current tool name hasn't been sent, send if available # - otherwise send nothing - elif not self.current_tool_name_sent: + if not self.current_tool_name_sent: function_name = current_tool_call.get("name") if function_name: delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True))