[Tool parsing] Improve / correct mistral tool parsing #10333
@@ -58,17 +61,62 @@
},
"required": ["city", "state", "unit"]
}
},
Make the test much harder and more complex, to show the community how far function calling can go with Mistral models
model_output = outputs[0].outputs[0].text.strip()
assert model_output.startswith(tool_parser.bot_token), model_output
parsed_message = tool_parser.extract_tool_calls(model_output, None)
Cleaner to let the parser take care of correctly extracting the dict
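A minimal sketch of what letting the parser extract the dict can look like, assuming the model output is the [TOOL_CALLS] token followed by a JSON list of calls. The function name and the BOT_TOKEN constant are illustrative only; this is not vLLM's actual parser implementation:

```python
import json

# Assumption: the tool-call marker emitted by Mistral models.
BOT_TOKEN = "[TOOL_CALLS]"


def extract_tool_calls_sketch(model_output: str) -> list[dict]:
    """Return the tool calls encoded after BOT_TOKEN, or [] if there are none."""
    text = model_output.strip()
    if not text.startswith(BOT_TOKEN):
        return []
    # json.loads tolerates the whitespace left after the marker.
    payload = json.loads(text[len(BOT_TOKEN):])
    # Accept a single call object as well as a list of calls.
    return [payload] if isinstance(payload, dict) else list(payload)
```

The test then only needs to compare the returned dicts with the expected call, instead of slicing the raw string itself.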
break
request.messages[i][
"tool_calls"] = validated_tool_calls
maybe_serialize_tool_calls(request)
Moving this out of serving_chat.py just to clean up the method a bit. This is a very general method and the error correction here is very Mistral-specific, so it is probably better placed in tokenizers.mistral.py
Good point!
I had originally thought about putting it directly in the Mistral tokenizer, but in the end did not, because the same problem would occur for any future model whose tokenizer does not rely on Jinja chat templates (there are none right now, so this was highly hypothetical).
Factoring the logic into a function like you did is a good solution that would still work with other non-chat-template models 👍
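For readers following the thread, the helper under discussion can be sketched roughly as below. It is reconstructed from the diff fragments in this review and assumes the goal is to turn lazily validated tool_calls iterators on assistant messages into plain lists. The SimpleNamespace request is a stand-in, not vLLM's ChatCompletionRequest:

```python
from types import SimpleNamespace


def maybe_serialize_tool_calls_sketch(request) -> None:
    """Materialize each assistant message's tool_calls iterator into a list."""
    for i, message in enumerate(request.messages):
        if message.get("role") != "assistant":
            continue
        tool_calls_iter = iter(message.get("tool_calls", ()))
        validated_tool_calls = []
        while True:
            try:
                validated_tool_calls.append(next(tool_calls_iter))
            except StopIteration:
                break
        request.messages[i]["tool_calls"] = validated_tool_calls


# Stand-in request (hypothetical) whose tool_calls is a one-shot iterator.
request = SimpleNamespace(messages=[
    {"role": "user", "content": "What's the weather in Dallas?"},
    {"role": "assistant", "tool_calls": iter([{"id": "call_0", "type": "function"}])},
])
maybe_serialize_tool_calls_sketch(request)
```

Because the function only touches request.messages, it does not depend on a chat template and can be called from any code path that handles such requests.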
request.messages[i]["tool_calls"] = validated_tool_calls
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
As proposed by @gcalmettes here: #9059 (comment)
We don't strip the [TOOL_CALLS] token for either tekken or spm, so that function calls can be parsed correctly.
Thanks for making this PR! I think it's a lot cleaner now.
@@ -222,7 +260,8 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str:
if self.is_tekken:
tokens = [
t for t in tokens
if t not in self.tokenizer._all_special_tokens
if (t is SpecialTokens.tool_calls
Note that after further testing on my end, I found an edge case where not skipping the [TOOL_CALLS] token here can mess up the intended output:
- when requiring structured output by specifying response_format=json_object or response_format=json_schema, the [TOOL_CALLS] token is still emitted in some cases even though no tools are provided to the model, so the generated output is no longer valid JSON. I have tested and observed this with all the vLLM-supported structured output backends (lm-format-enforcer / outlines). Note that this only happens if the system prompt does not mention that we expect JSON responses from the model.
If we can find a way to keep the SpecialTokens.tool_calls token only when function calling is required (based on the presence of tools in the request, for example), that would be best. However, I haven't yet found a clean way to pass this information to the convert_tokens_to_string method without changing its signature ...
I have an easily reproducible example of this problem that I can share with you.
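One way to picture the behaviour wished for above is a filter that keeps the tool-call token only when the request carries tools. This is a sketch under assumptions: the keep_tool_calls flag is hypothetical, and it is exactly the extra piece of information that convert_tokens_to_string does not currently receive:

```python
TOOL_CALLS = "[TOOL_CALLS]"


def filter_special_tokens(tokens, special_tokens, keep_tool_calls):
    """Drop special tokens, optionally sparing the tool-call marker."""
    return [
        t for t in tokens
        if t not in special_tokens or (keep_tool_calls and t == TOOL_CALLS)
    ]


special = {"<s>", "</s>", TOOL_CALLS}
tokens = ["<s>", TOOL_CALLS, '[{"name"', "</s>"]

# Request with tools: the marker survives so the tool parser can see it.
with_tools = filter_special_tokens(tokens, special, keep_tool_calls=True)
# Request without tools (e.g. plain JSON mode): the marker is stripped.
without_tools = filter_special_tokens(tokens, special, keep_tool_calls=False)
```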
Thanks for the note! Would be great if you could share an easy repro
@patrickvonplaten please find below a scenario where it will break (and, further below, the small change in the prompt that makes the code work, because of the added guidance to the model). Note that the code requires lm-format-enforcer version 0.10.9 so that it is compatible with the MistralTokenizer.
However, after further investigation, I now know how to fix it (I'm preparing a PR, I'll tag you for review)! In fact the problem was present before but "masked" by the fact that the [TOOL_CALLS] token was skipped in the convert_tokens_to_string method, so your PR made it possible to expose the problem 😉. (The root cause is that all the structured output libraries filter out the special tokens to build their tree of possible tokens, e.g. this check in lm-format-enforcer, but the current vLLM MistralTokenizer does not correctly populate the methods that the libraries use for that. The fix is easy, and I have tested it successfully.)
"""
vllm server started with the following arguments:
    --guided-decoding-backend=lm-format-enforcer
    --enable-auto-tool-choice
    --tool-call-parser=mistral
    --tokenizer-mode=mistral
"""
from openai import OpenAI
from pydantic import BaseModel

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="none",
)


class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]


completion = client.beta.chat.completions.parse(
    model="mistralai/Pixtral-12B-2409",
    messages=[
        {"role": "system", "content": "Extract the event information."},
        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
    ],
    response_format=CalendarEvent,
)

# the response will break as `[TOOL_CALLS]` is present at the beginning of the response
event = completion.choices[0].message.parsed
print(event.__dict__)
Guiding the model to output JSON by changing the system prompt as below is enough for the model not to produce a tool-call token:
{"role": "system", "content": "Extract the event information. Respond as JSON."},
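The root cause described above can be illustrated with a toy vocabulary. This is a sketch, not lm-format-enforcer's code: allowed_token_ids and the vocabulary are made up, and the only point is that a constrained decoder can exclude special tokens only if the tokenizer reports them:

```python
def allowed_token_ids(vocab: dict[str, int], special_ids: set[int]) -> set[int]:
    """Token ids a constrained decoder may sample: all ids not reported as special."""
    return {tok_id for tok_id in vocab.values() if tok_id not in special_ids}


# Toy vocabulary (hypothetical ids).
vocab = {"<s>": 0, "</s>": 1, "[TOOL_CALLS]": 2, "{": 3, "}": 4}

# Tokenizer that reports its special tokens: the marker is excluded.
reported = allowed_token_ids(vocab, special_ids={0, 1, 2})
# Tokenizer that reports none: the marker stays sampleable.
unreported = allowed_token_ids(vocab, special_ids=set())
```

In the second case the structured-output backend has no way to know that [TOOL_CALLS] should never appear in a JSON answer, which matches the behaviour seen in the repro.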
This PR is heavily inspired / copied from what @gcalmettes nicely summarized here: #9059 (comment) and in following messages. Thanks a ton for the nice investigation and great ideas of how to improve Mistral function calling.
Based on @gcalmettes's idea here #9059 (comment) both tekken models (mistral-nemo) and spm models (mistral-8b) output the [TOOL_CALLS] token so that it can be consumed by the tool parser and hence allow for more robust function calling parsing, e.g.:
and then ping the model e.g. via:
This PR should then also finally close: #9059.