Skip to content

Commit

Permalink
[Core] [Tool Call] adjust conversable agent to support tool_calls (#974)
Browse files Browse the repository at this point in the history
* adjust conversable and compressible agents to support tool_calls

* split out tools into their own reply def

* copilot typo

* address review comments

* revert compressible_agent and token_count_utils calls

* cleanup terminate check and remove unnecessary code

* doc search and update

* return function/tool calls as interrupted when user provides a reply to a tool call request

* fix tool name reference

* fix formatting

* fix initiate receiving a dict

* missed changed roled

* ignore incoming role, more similiar to existing code

* consistency

* redundant to_dict

* fix todo comment

* uneeded change

* handle dict reply in groupchat

* Fix generate_tool_call_calls_reply_comment

* change method annotation for register_for_llm from functions to tools

* typo autogen/agentchat/conversable_agent.py

Co-authored-by: Chi Wang <[email protected]>

* add deprecation comments for function_call

* tweak doc strings

* switch to ToolFunction type

* update the return to

* fix generate_init_message return type

* Revert "fix generate_init_message return type"

This reverts commit 645ba8b.

* undo force init to dict

* fix notebooks and groupchat tool handling

* fix type

* use get for key error

* fix teachable to pull content from dict

* change single message tool response

* cleanup unnessary changes

* little better tool response concatenation

* update tools tests

* add skip openai check to tools tests

* fix nits

* move func name normalization to oai_reply and assert configured names

* fix whitespace

* remove extra normalize

* tool name is now normalized in the generate_reply function, so will not be incorrect when sent to receive

* validate function names in init and expand comments for validation methods

* fix dict comprehension

* Dummy llm config for unit tests

* handle tool_calls set to None

* fix tool name reference

* method operates on responses not calls

---------

Co-authored-by: Yiran Wu <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
Co-authored-by: Eric Zhu <[email protected]>
  • Loading branch information
4 people authored Jan 6, 2024
1 parent e673500 commit 40dbf31
Show file tree
Hide file tree
Showing 9 changed files with 778 additions and 211 deletions.
349 changes: 309 additions & 40 deletions autogen/agentchat/conversable_agent.py

Large diffs are not rendered by default.

31 changes: 25 additions & 6 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,21 @@ def _prepare_and_select_agents(self, last_speaker: Agent) -> Tuple[Optional[Agen
"Or, use direct communication instead."
)

if self.func_call_filter and self.messages and "function_call" in self.messages[-1]:
if (
self.func_call_filter
and self.messages
and ("function_call" in self.messages[-1] or "tool_calls" in self.messages[-1])
):
funcs = []
if "function_call" in self.messages[-1]:
funcs += [self.messages[-1]["function_call"]["name"]]
if "tool_calls" in self.messages[-1]:
funcs += [
tool["function"]["name"] for tool in self.messages[-1]["tool_calls"] if tool["type"] == "function"
]

# find agents with the right function_map which contains the function name
agents = [
agent for agent in self.agents if agent.can_execute_function(self.messages[-1]["function_call"]["name"])
]
agents = [agent for agent in self.agents if agent.can_execute_function(funcs)]
if len(agents) == 1:
# only one agent can execute the function
return agents[0], agents
Expand All @@ -170,7 +180,7 @@ def _prepare_and_select_agents(self, last_speaker: Agent) -> Tuple[Optional[Agen
return agents[0], agents
elif not agents:
raise ValueError(
f"No agent can execute the function {self.messages[-1]['function_call']['name']}. "
f"No agent can execute the function {', '.join(funcs)}. "
"Please check the function_map of the agents."
)
# remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False
Expand All @@ -193,7 +203,14 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
return selected_agent
# auto speaker selection
selector.update_system_message(self.select_speaker_msg(agents))
context = self.messages + [{"role": "system", "content": self.select_speaker_prompt(agents)}]

# If last message is a tool call or function call, blank the call so the api doesn't throw
messages = self.messages.copy()
if messages[-1].get("function_call", False):
messages[-1] = dict(messages[-1], function_call=None)
if messages[-1].get("tool_calls", False):
messages[-1] = dict(messages[-1], tool_calls=None)
context = messages + [{"role": "system", "content": self.select_speaker_prompt(agents)}]
final, name = selector.generate_oai_reply(context)

if not final:
Expand Down Expand Up @@ -275,6 +292,8 @@ def _mentioned_agents(self, message_content: Union[str, List], agents: List[Agen
Dict: a counter for mentioned agents.
"""
# Cast message content to str
if isinstance(message_content, dict):
message_content = message_content["content"]
message_content = content_str(message_content)

mentions = dict()
Expand Down
17 changes: 13 additions & 4 deletions autogen/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ class Function(BaseModel):
parameters: Annotated[Parameters, Field(description="Parameters of the function")]


class ToolFunction(BaseModel):
"""A function under tool as defined by the OpenAI API."""

type: Literal["function"] = "function"
function: Annotated[Function, Field(description="Function under tool")]


def get_parameter_json_schema(
k: str, v: Union[Annotated[Type, str], Type], default_values: Dict[str, Any]
) -> JsonSchemaValue:
Expand Down Expand Up @@ -260,10 +267,12 @@ def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Paramet

parameters = get_parameters(required, param_annotations, default_values=default_values)

function = Function(
description=description,
name=fname,
parameters=parameters,
function = ToolFunction(
function=Function(
description=description,
name=fname,
parameters=parameters,
)
)

return model_dump(function)
Expand Down
8 changes: 4 additions & 4 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,9 @@ def yes_or_no_filter(context, response):

def _completions_create(self, client, params):
completions = client.chat.completions if "messages" in params else client.completions
# If streaming is enabled, has messages, and does not have functions, then
# If streaming is enabled, has messages, and does not have functions or tools, then
# iterate over the chunks of the response
if params.get("stream", False) and "messages" in params and "functions" not in params:
if params.get("stream", False) and "messages" in params and "functions" not in params and "tools" not in params:
response_contents = [""] * params.get("n", 1)
finish_reasons = [""] * params.get("n", 1)
completion_tokens = 0
Expand Down Expand Up @@ -352,8 +352,8 @@ def _completions_create(self, client, params):

response.choices.append(choice)
else:
# If streaming is not enabled or using functions, send a regular chat completion request
# Functions are not supported, so ensure streaming is disabled
# If streaming is not enabled, using functions, or tools, send a regular chat completion request
# Functions and Tools are not supported, so ensure streaming is disabled
params = params.copy()
params["stream"] = False
response = completions.create(**params)
Expand Down
116 changes: 67 additions & 49 deletions notebook/agentchat_function_call_currency_calculator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -185,20 +185,21 @@
{
"data": {
"text/plain": [
"[{'description': 'Currency exchange calculator.',\n",
" 'name': 'currency_calculator',\n",
" 'parameters': {'type': 'object',\n",
" 'properties': {'base_amount': {'type': 'number',\n",
" 'description': 'Amount of currency in base_currency'},\n",
" 'base_currency': {'enum': ['USD', 'EUR'],\n",
" 'type': 'string',\n",
" 'default': 'USD',\n",
" 'description': 'Base currency'},\n",
" 'quote_currency': {'enum': ['USD', 'EUR'],\n",
" 'type': 'string',\n",
" 'default': 'EUR',\n",
" 'description': 'Quote currency'}},\n",
" 'required': ['base_amount']}}]"
"[{'type': 'function',\n",
" 'function': {'description': 'Currency exchange calculator.',\n",
" 'name': 'currency_calculator',\n",
" 'parameters': {'type': 'object',\n",
" 'properties': {'base_amount': {'type': 'number',\n",
" 'description': 'Amount of currency in base_currency'},\n",
" 'base_currency': {'enum': ['USD', 'EUR'],\n",
" 'type': 'string',\n",
" 'default': 'USD',\n",
" 'description': 'Base currency'},\n",
" 'quote_currency': {'enum': ['USD', 'EUR'],\n",
" 'type': 'string',\n",
" 'default': 'EUR',\n",
" 'description': 'Quote currency'}},\n",
" 'required': ['base_amount']}}}]"
]
},
"execution_count": 4,
Expand All @@ -207,7 +208,7 @@
}
],
"source": [
"chatbot.llm_config[\"functions\"]"
"chatbot.llm_config[\"tools\"]"
]
},
{
Expand Down Expand Up @@ -259,10 +260,14 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested function Call: currency_calculator *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_2mZCDF9fe8WJh6SveIwdGGEy): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
"{\"base_amount\":123.45,\"base_currency\":\"USD\",\"quote_currency\":\"EUR\"}\n",
"\u001b[32m********************************************************\u001b[0m\n",
"{\n",
" \"base_amount\": 123.45,\n",
" \"base_currency\": \"USD\",\n",
" \"quote_currency\": \"EUR\"\n",
"}\n",
"\u001b[32m************************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
Expand All @@ -276,7 +281,7 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"123.45 USD is equivalent to approximately 112.23 EUR.\n",
"123.45 USD is approximately 112.23 EUR.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
Expand Down Expand Up @@ -370,27 +375,28 @@
{
"data": {
"text/plain": [
"[{'description': 'Currency exchange calculator.',\n",
" 'name': 'currency_calculator',\n",
" 'parameters': {'type': 'object',\n",
" 'properties': {'base': {'properties': {'currency': {'description': 'Currency symbol',\n",
" 'enum': ['USD', 'EUR'],\n",
" 'title': 'Currency',\n",
" 'type': 'string'},\n",
" 'amount': {'default': 0,\n",
" 'description': 'Amount of currency',\n",
" 'minimum': 0.0,\n",
" 'title': 'Amount',\n",
" 'type': 'number'}},\n",
" 'required': ['currency'],\n",
" 'title': 'Currency',\n",
" 'type': 'object',\n",
" 'description': 'Base currency: amount and currency symbol'},\n",
" 'quote_currency': {'enum': ['USD', 'EUR'],\n",
" 'type': 'string',\n",
" 'default': 'USD',\n",
" 'description': 'Quote currency symbol'}},\n",
" 'required': ['base']}}]"
"[{'type': 'function',\n",
" 'function': {'description': 'Currency exchange calculator.',\n",
" 'name': 'currency_calculator',\n",
" 'parameters': {'type': 'object',\n",
" 'properties': {'base': {'properties': {'currency': {'description': 'Currency symbol',\n",
" 'enum': ['USD', 'EUR'],\n",
" 'title': 'Currency',\n",
" 'type': 'string'},\n",
" 'amount': {'default': 0,\n",
" 'description': 'Amount of currency',\n",
" 'minimum': 0.0,\n",
" 'title': 'Amount',\n",
" 'type': 'number'}},\n",
" 'required': ['currency'],\n",
" 'title': 'Currency',\n",
" 'type': 'object',\n",
" 'description': 'Base currency: amount and currency symbol'},\n",
" 'quote_currency': {'enum': ['USD', 'EUR'],\n",
" 'type': 'string',\n",
" 'default': 'USD',\n",
" 'description': 'Quote currency symbol'}},\n",
" 'required': ['base']}}}]"
]
},
"execution_count": 8,
Expand All @@ -399,7 +405,7 @@
}
],
"source": [
"chatbot.llm_config[\"functions\"]"
"chatbot.llm_config[\"tools\"]"
]
},
{
Expand All @@ -419,10 +425,16 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested function Call: currency_calculator *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_MLtsPcVJXhdpvDPNNxfTB3OB): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
"{\"base\":{\"currency\":\"EUR\",\"amount\":112.23},\"quote_currency\":\"USD\"}\n",
"\u001b[32m********************************************************\u001b[0m\n",
"{\n",
" \"base\": {\n",
" \"currency\": \"EUR\",\n",
" \"amount\": 112.23\n",
" },\n",
" \"quote_currency\": \"USD\"\n",
"}\n",
"\u001b[32m************************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
Expand All @@ -436,7 +448,7 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"112.23 Euros is equivalent to approximately 123.45 US Dollars.\n",
"112.23 Euros is approximately 123.45 US Dollars.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
Expand Down Expand Up @@ -477,10 +489,16 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested function Call: currency_calculator *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_WrBjnoLeXilBPuj9nTJLM5wh): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
"{\"base\":{\"currency\":\"USD\",\"amount\":123.45},\"quote_currency\":\"EUR\"}\n",
"\u001b[32m********************************************************\u001b[0m\n",
"{\n",
" \"base\": {\n",
" \"currency\": \"USD\",\n",
" \"amount\": 123.45\n",
" },\n",
" \"quote_currency\": \"EUR\"\n",
"}\n",
"\u001b[32m************************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
Expand Down Expand Up @@ -543,7 +561,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.11.6"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 40dbf31

Please sign in to comment.