Skip to content

Commit

Permalink
[Tool parsing] Improve / correct mistral tool parsing (#10333)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten authored Nov 15, 2024
1 parent 554af92 commit 11cd1ae
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 59 deletions.
93 changes: 82 additions & 11 deletions tests/models/decoder_only/language/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
Run `pytest tests/models/test_mistral.py`.
"""
import copy

import pytest

from vllm import SamplingParams
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
MistralToolParser)

from ...utils import check_logprobs_close

Expand Down Expand Up @@ -58,17 +62,69 @@
},
"required": ["city", "state", "unit"]
}
},
}, {
"type": "function",
"function": {
"name": "rewrite",
"description": "Rewrites text",
"parameters": {
"type": "object",
"required": [],
"properties": {
"text": {
"type": "string",
"description": "The input text to rewrite."
}
}
}
}
}]
MSGS = [{
"role":
"user",
"content": ("Can you tell me what the temperate"
" will be in Dallas, in fahrenheit?")
}]
EXPECTED_FUNC_CALL = (
'[{"name": "get_current_weather", "arguments": '
'{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')
MSGS = [
{
"role": "system",
"content": "You are an assistant."
},
{
"role":
"user",
"content":
"Could you please rewrite the below article? \n\n My English needs improvving, maybe I make errors." # noqa
},
{
"role":
"assistant",
"content":
"",
"tool_calls": [{
"id": "bbc5b7ede",
"type": "function",
"function": {
"name":
"rewrite",
"arguments":
'{\"text\":\"My English needs improvving, maybe I make errors.\"}' # noqa
}
}]
},
{
"role": "tool",
"content":
"{\"action\":\"rewrite\",\"outcome\":\"My English needs improving, maybe I make errors.\"}", # noqa
"tool_call_id": "bbc5b7ede",
"name": "rewrite"
},
{
"role": "assistant",
"content": "---\n\nMy English needs improving, maybe I make errors"
},
{
"role":
"user",
"content": ("Can you tell me what the temperate"
" will be in Dallas, in fahrenheit?")
}
]


@pytest.mark.parametrize("model", MODELS)
Expand Down Expand Up @@ -175,8 +231,23 @@ def test_mistral_function_calling(
tokenizer_mode="mistral",
config_format="mistral",
load_format="mistral") as vllm_model:
outputs = vllm_model.model.chat(MSGS,

msgs = copy.deepcopy(MSGS)
outputs = vllm_model.model.chat(msgs,
tools=TOOLS,
sampling_params=SAMPLING_PARAMS)

assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL
tokenizer = vllm_model.model.get_tokenizer()
tool_parser = MistralToolParser(tokenizer)

model_output = outputs[0].outputs[0].text.strip()
assert model_output.startswith(tool_parser.bot_token), model_output
parsed_message = tool_parser.extract_tool_calls(model_output, None)

assert parsed_message.tools_called
assert parsed_message.tool_calls[0].id == "0UAqFzWsD"
assert parsed_message.tool_calls[
0].function.name == "get_current_weather"
assert parsed_message.tool_calls[
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
assert parsed_message.content is None
39 changes: 5 additions & 34 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
from vllm.utils import iterate_with_cancellation

logger = init_logger(__name__)
Expand Down Expand Up @@ -127,41 +128,11 @@ async def create_chat_completion(
return self.create_error_response(
"tool_choice = \"required\" is not supported!")

# NOTE: There is currently a bug in pydantic where attributes
# declared as iterables are replaced in in the instances by
# pydantic-core ValidatorIterator instance. In particular, this
# affects tool_calls defined in ChatCompletionAssistantMessageParam
# model:
# see:
# - https://github.com/pydantic/pydantic/issues/9467
# As a result, tool_calls from assistant messages are never
# deserialized in the request object if the tool_calls iterator is
# not consumed. This affect messages passed to the MistralTokenizer
# since no chat template is applied and therefore the tools_calls
# iterator is not directly consumed.
# Issue is tracked on Pydantic side, with resolution planned for
# v2.11 release. In the meantime, the official workaround is to
# consume the iterator so the tool_calls are correctly deserialized
# in the OpenAI ChatCompletionAssistantMessageParam object
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
# Official Pydantic Issues:
# - https://github.com/pydantic/pydantic/issues/9541
# TODO: remove when pydantic v2.11 is released
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
if isinstance(tokenizer, MistralTokenizer):
for i, message in enumerate(request.messages):
if message.get("role") == 'assistant':
tool_calls_validator = message.get(
"tool_calls", ().__iter__())
validated_tool_calls = []
while True:
try:
tool_call = next(
tool_calls_validator) # type: ignore
validated_tool_calls.append(tool_call)
except StopIteration:
break
request.messages[i][
"tool_calls"] = validated_tool_calls
maybe_serialize_tool_calls(request)

if (request.tool_choice == "auto" and
not (self.enable_auto_tools and tool_parser is not None)
Expand Down
25 changes: 17 additions & 8 deletions vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, tokenizer: AnyTokenizer):
] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if self.bot_token_id is None:
raise RuntimeError(
"Mistral Tool Parser could not locate the tool call token in "
Expand All @@ -84,16 +84,25 @@ def extract_tool_calls(
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)

# first remove the BOT token
tool_content = model_output.replace(self.bot_token, "").strip()

try:

# use a regex to find the tool call. remove the BOT token
# and make sure to replace single quotes with double quotes
raw_tool_call = self.tool_call_regex.findall(
model_output.replace(self.bot_token, ""))[0]
# we first try to directly load the json as parsing very nested
# jsons is difficult
try:
function_call_arr = json.loads(tool_content)
except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained
# correctly. It's a easy possible fix so it's included, but
# can be brittle for very complex / highly nested tool calls
raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
function_call_arr = json.loads(raw_tool_call)

# load the JSON, and then use it to build the Function and
# Tool Call
function_call_arr = json.loads(raw_tool_call)
tool_calls: List[MistralToolCall] = [
MistralToolCall(
type="function",
Expand All @@ -116,7 +125,7 @@ def extract_tool_calls(
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
content=tool_content)

def extract_tool_calls_streaming(
self,
Expand Down
4 changes: 2 additions & 2 deletions vllm/transformers_utils/tokenizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .mistral import MistralTokenizer
from .mistral import MistralTokenizer, maybe_serialize_tool_calls

__all__ = ["MistralTokenizer"]
__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]
70 changes: 66 additions & 4 deletions vllm/transformers_utils/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import huggingface_hub
from huggingface_hub import HfApi, hf_hub_download
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.base import SpecialTokens
# yapf: disable
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as PublicMistralTokenizer)
Expand All @@ -29,6 +30,43 @@ class Encoding:
input_ids: List[int]


def maybe_serialize_tool_calls(request: ChatCompletionRequest):
# SEE: https://github.com/vllm-project/vllm/pull/9951
# Credits go to: @gcalmettes
# NOTE: There is currently a bug in pydantic where attributes
# declared as iterables are replaced in in the instances by
# pydantic-core ValidatorIterator instance. In particular, this
# affects tool_calls defined in ChatCompletionAssistantMessageParam
# model:
# see:
# - https://github.com/pydantic/pydantic/issues/9467
# As a result, tool_calls from assistant messages are never
# deserialized in the request object if the tool_calls iterator is
# not consumed. This affect messages passed to the MistralTokenizer
# since no chat template is applied and therefore the tools_calls
# iterator is not directly consumed.
# Issue is tracked on Pydantic side, with resolution planned for
# v2.11 release. In the meantime, the official workaround is to
# consume the iterator so the tool_calls are correctly deserialized
# in the OpenAI ChatCompletionAssistantMessageParam object
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
# Official Pydantic Issues:
# - https://github.com/pydantic/pydantic/issues/9541
# TODO: remove when pydantic v2.11 is released
for i, message in enumerate(request.messages):
if message.get("role") == 'assistant':
tool_calls_validator = message.get("tool_calls", ().__iter__())
validated_tool_calls = []
while True:
try:
tool_call = next(tool_calls_validator) # type: ignore
validated_tool_calls.append(tool_call)
except StopIteration:
break

request.messages[i]["tool_calls"] = validated_tool_calls


def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
repo_cache = os.path.join(
huggingface_hub.constants.HF_HUB_CACHE,
Expand Down Expand Up @@ -222,7 +260,8 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str:
if self.is_tekken:
tokens = [
t for t in tokens
if t not in self.tokenizer._all_special_tokens
if (t is SpecialTokens.tool_calls
or t not in self.tokenizer._all_special_tokens)
]

if any(isinstance(t, bytes) for t in tokens):
Expand All @@ -246,7 +285,27 @@ def _token_to_id(t: str):
else:
decoded = "".join(tokens)
else:
decoded = self.tokenizer.decode(tokens) # type: ignore[arg-type]
# make sure certain special tokens like Tool calls are
# not decoded
special_tokens = {SpecialTokens.tool_calls}
regular_tokens: List[str] = []
decoded_list = []

for token in tokens:
if token in special_tokens:
if regular_tokens:
decoded_list.append(
self.tokenizer.decode(regular_tokens))
regular_tokens = []
decoded_list.append(token)
else:
regular_tokens.append(token)

if regular_tokens:
decoded_list.append(
self.decode(regular_tokens)) # type: ignore

decoded = ''.join(decoded_list)

return decoded

Expand Down Expand Up @@ -274,8 +333,11 @@ def convert_ids_to_tokens(
assert self.is_tekken or self.is_spm, type(self.tokenizer)

if self.is_tekken:
# skip special tokens
ids = [i for i in ids if i > self.tokenizer.num_special_tokens]
# skip special tokens except tool call
ids = [
i for i in ids if i > self.tokenizer.num_special_tokens or i ==
self.tokenizer.get_control_token(SpecialTokens.tool_calls)
]

tokens = [self.tokenizer.id_to_piece(id) for id in ids]

Expand Down

0 comments on commit 11cd1ae

Please sign in to comment.