[Frontend] Tool calling parser for Granite 3.0 models #9027

Merged: 10 commits, Nov 7, 2024
44 changes: 26 additions & 18 deletions docs/source/serving/openai_compatible_server.md
@@ -127,14 +127,7 @@ this, unless explicitly specified.
:func: create_parser_for_docs
:prog: vllm serve
```
## Tool Calling in the Chat Completion API
### Named Function Calling
vLLM supports only named function calling in the chat completion API by default. It does so using Outlines, so this is
enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a
high-quality one.

To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and
specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request.

### Config file

@@ -163,12 +156,22 @@ The order of priorities is `command line > config file values > defaults`.
---

## Tool calling in the chat completion API
vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap.

vLLM supports named function calling and `auto` tool choice in the chat completion API. The `tool_choice` option `required` is **not yet supported** but is on the roadmap.

It is the caller's responsibility to prompt the model with the tool information; vLLM will not automatically manipulate the prompt.


### Named Function Calling
vLLM supports named function calling in the chat completion API by default. It does so using Outlines, so it is
enabled out of the box and works with any supported model. You are guaranteed a validly parsable function call, though
not necessarily a high-quality one.

vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.

To use a named function, define the functions in the `tools` parameter of the chat completion request and
specify the `name` of one of those tools in the `tool_choice` parameter, as in the sketch below.
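
For illustration, here is a minimal sketch of such a request using the OpenAI Python client; the server URL, model name, and `get_weather` tool are placeholders, not part of vLLM itself.

```python
# Minimal sketch of named function calling (hypothetical tool and model names).
# Assumes a vLLM OpenAI-compatible server is already running on localhost:8000.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="ibm-granite/granite-3.0-8b-instruct",  # any served model works here
    messages=[{"role": "user", "content": "What's the weather in Lisbon?"}],
    tools=tools,
    # Named function calling: force a call to this specific tool.
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
)
print(response.choices[0].message.tool_calls[0].function.arguments)
```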


### Automatic Function Calling
To enable this feature, you should set the following flags:
@@ -242,6 +245,21 @@ it works better with vLLM.

Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja`

#### IBM Granite

Supported models:
* `ibm-granite/granite-3.0-8b-instruct`

Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja`

`examples/tool_chat_template_granite.jinja`: a modified version of the original chat template on Hugging Face. Parallel function calls are supported.

* `ibm-granite/granite-20b-functioncalling`

Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja`

`examples/tool_chat_template_granite_20b_fc.jinja`: a modified version of the original chat template on Hugging Face, which is not compatible with vLLM. It blends function description elements from the Hermes template and follows the same system prompt as the "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported.
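
For illustration only, the sketch below shows an `auto` tool choice request; it assumes a server started with `--enable-auto-tool-choice` plus the Granite 3.0 flags above, and the `get_weather` tool is made up.

```python
# Sketch of automatic tool choice with the Granite 3.0 setup (hypothetical tool).
# Assumed server launch: vllm serve ibm-granite/granite-3.0-8b-instruct \
#   --enable-auto-tool-choice --tool-call-parser granite \
#   --chat-template examples/tool_chat_template_granite.jinja
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="ibm-granite/granite-3.0-8b-instruct",
    messages=[{"role": "user", "content": "What's the weather in Lisbon?"}],
    tools=tools,
    tool_choice="auto",  # let the model decide whether to call a tool
)

message = response.choices[0].message
if message.tool_calls:  # the parser extracted one or more tool calls
    for call in message.tool_calls:
        print(call.function.name, call.function.arguments)
else:
    print(message.content)
```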


#### InternLM Models (`internlm`)

@@ -264,16 +282,6 @@ AI21's Jamba-1.5 models are supported.
Flags: `--tool-call-parser jamba`


#### IBM Granite (`granite-20b-fc`)

Supported models:
* `ibm-granite/granite-20b-functioncalling`

Flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja`

The example chat template deviates slightly from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported.


### How to write a tool parser plugin

A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py.
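
For orientation, here is a bare-bones sketch of such a plugin. It mirrors the registration decorator and method signatures used by the parsers in this repository; the `example` name and the pass-through behavior are placeholders.

```python
# Bare-bones ToolParser plugin sketch (the "example" name is a placeholder).
from typing import Sequence, Union

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage,
                                              ExtractedToolCallInformation)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)
from vllm.transformers_utils.tokenizer import AnyTokenizer


@ToolParserManager.register_module("example")
class ExampleToolParser(ToolParser):

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

    # Non-streaming path: inspect the complete model output and extract any
    # tool calls it contains.
    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        # Placeholder: report no tool calls and return the text unchanged.
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)

    # Streaming path: called for every generated chunk with the running text
    # and token context; return a DeltaMessage (or None to emit nothing).
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        # Placeholder: stream the delta through as plain content.
        return DeltaMessage(content=delta_text)
```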
40 changes: 40 additions & 0 deletions examples/tool_chat_template_granite.jinja
@@ -0,0 +1,40 @@
{%- if tools %}
{{- '<|start_of_role|>available_tools<|end_of_role|>
' }}
{%- for tool in tools %}
{{- tool | tojson(indent=4) }}
{%- if not loop.last %}
{{- '

' }}
{%- endif %}
{%- endfor %}
{{- '<|end_of_text|>
' }}
{%- endif %}

{%- for message in messages %}
{%- if message['role'] == 'system' %}
{{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- elif message['role'] == 'user' %}
{{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- elif message['role'] == 'assistant_tool_call' or (message['role'] == 'assistant' and message.tool_calls is defined) %}
{{- '<|start_of_role|>assistant<|end_of_role|>' }}
{% for tc in message.tool_calls %}
{{- '<|tool_call|> ' + {'name': tc.function.name, 'arguments': tc.function.arguments}|tojson }}
{% endfor %}
{{- '<|end_of_text|>
' }}
{%- elif message['role'] == 'assistant' %}
{{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- elif message['role'] == 'tool_response' or message['role'] == 'tool' %}
{{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- endif %}
{%- if loop.last and add_generation_prompt %}
{{- '<|start_of_role|>assistant<|end_of_role|>' }}
{%- endif %}
{%- endfor %}
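
As a quick local sanity check (not part of this PR), the template can be rendered with jinja2 to inspect the prompt it produces; the tool and message below are made up, and the path assumes you run it from a vLLM checkout.

```python
# Render the Granite tool chat template locally to inspect the prompt format.
from pathlib import Path

from jinja2 import Template

template = Template(
    Path("examples/tool_chat_template_granite.jinja").read_text())

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
        },
    },
}]
messages = [{"role": "user", "content": "What's the weather in Lisbon?"}]

# add_generation_prompt=True appends the trailing assistant role header.
print(template.render(tools=tools, messages=messages,
                      add_generation_prompt=True))
```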
34 changes: 21 additions & 13 deletions tests/tool_use/utils.py
@@ -36,7 +36,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],

# universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something.
ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"]
ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"]

CONFIGS: Dict[str, ServerConfig] = {
"hermes": {
@@ -88,18 +88,26 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
},
## FIXME: temporary disabled due to lack of hardware specification
## for individual runs
#"granite20b": {
# "model":
# "ibm-granite/granite-20b-functioncalling",
# "arguments": [
# "--tool-call-parser", "granite-20b-fc", "--chat-template",
# str(VLLM_PATH / "examples/tool_chat_template_granite_20b_fc.jinja")
# ],
# "supports_parallel":
# False,
#},
"granite20b": {
"model":
"mbayser/granite-20b-functioncalling-FP8-KV",
"arguments": [
"--tool-call-parser", "granite-20b-fc", "--chat-template",
str(VLLM_PATH /
"examples/tool_chat_template_granite_20b_fc.jinja"),
"--max_num_seqs", "1", "--enforce-eager", "--cpu-offload-gb", "20"
],
"supports_parallel":
False,
},
"granite8b": {
"model":
"ibm-granite/granite-3.0-8b-instruct",
"arguments": [
"--tool-call-parser", "granite", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_granite.jinja")
],
},
"internlm": {
"model":
"internlm/internlm2_5-7b-chat",
5 changes: 3 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -1,5 +1,6 @@
from .abstract_tool_parser import ToolParser, ToolParserManager
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
from .granite_tool_parser import GraniteToolParser
from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser
@@ -8,6 +9,6 @@

__all__ = [
"ToolParser", "ToolParserManager", "Granite20bFCToolParser",
"Hermes2ProToolParser", "MistralToolParser", "Internlm2ToolParser",
"Llama3JsonToolParser", "JambaToolParser"
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser"
]
215 changes: 215 additions & 0 deletions vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
@@ -0,0 +1,215 @@
import json
from typing import Dict, Sequence, Union

import partial_json_parser
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
find_common_prefix,
is_complete_json,
partial_json_loads)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)


@ToolParserManager.register_module("granite")
class GraniteToolParser(ToolParser):
"""
Tool call parser for the Granite 3.0 models. Intended
for use with the examples/tool_chat_template_granite.jinja
template.

Used when --enable-auto-tool-choice --tool-call-parser granite
are both set
"""

def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)

def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
stripped = model_output.strip()
if not stripped or stripped[0] != '[':
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
raw_function_calls = json.loads(stripped)
if not isinstance(raw_function_calls, list):
raise Exception(
f"Expected dict or list, got {type(raw_function_calls)}")

logger.debug("Extracted %d tool calls", len(raw_function_calls))
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"]),
),
) for function_call in raw_function_calls
]

return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=None,
)

except Exception as e:
logger.error("Error in extracting tool call from response %s", e)
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)

def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:

start_idx = consume_space(0, current_text)
if not current_text or current_text[start_idx] != '[':
return DeltaMessage(content=delta_text)

# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
tool_call_arr = None
is_complete = None
try:
tool_calls, end_idx = partial_json_loads(
current_text[start_idx:], flags)
if type(tool_calls) is list:
tool_call_arr = tool_calls
else:
return DeltaMessage(content=delta_text)

is_complete = [True] * len(tool_calls)
if not is_complete_json(
current_text[start_idx:start_idx + end_idx]):
is_complete[-1] = False
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None

# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if not tool_call_arr:
return None

# select the tool call that the streaming state currently points at
current_tool_call: Dict = tool_call_arr[self.current_tool_id]

delta = None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
if len(tool_call_arr) > self.current_tool_id + 1:

# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments)
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]

logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff

# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta

# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:

delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True

# now we know we're on the same tool call and we're streaming
# arguments
else:
cur_arguments = current_tool_call.get("arguments")

if cur_arguments:
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments)
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")

argument_diff = None
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(
prev_args_json, cur_args_json)
argument_diff = prefix[sent:]

if argument_diff is not None:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff

self.prev_tool_call_arr = tool_call_arr
return delta

except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None