Skip to content

Commit

Permalink
[Core] Compatibility with function call style API (Azure OpenAI and G…
Browse files Browse the repository at this point in the history
…emini) (#1227)

* #1206

* doc

* add test for azure openai

* prior to

* filter for versions

* up to

* literal type

* update doc
  • Loading branch information
ekzhu authored and joshkyh committed Jan 17, 2024
1 parent 2dd11e9 commit dec108e
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 2 deletions.
21 changes: 20 additions & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1701,6 +1701,7 @@ def register_for_llm(
*,
name: Optional[str] = None,
description: Optional[str] = None,
api_style: Literal["function", "tool"] = "tool",
) -> Callable[[F], F]:
"""Decorator factory for registering a function to be used by an agent.
Expand All @@ -1713,6 +1714,10 @@ def register_for_llm(
name (optional(str)): name of the function. If None, the function name will be used (default: None).
description (optional(str)): description of the function (default: None). It is mandatory
for the initial decorator, but the following ones can omit it.
api_style (literal): the API style for the function call.
For the Azure OpenAI API, use version 2023-12-01-preview or later.
The `"function"` style will be deprecated. For earlier versions, use
`"function"` if `"tool"` doesn't work.
Returns:
The decorator for registering a function to be used by an agent.
Expand All @@ -1726,6 +1731,14 @@ def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c
return a + str(b * c)
```
For Azure OpenAI versions 2023-10-01-preview and earlier, set `api_style`
to `"function"` if `"tool"` doesn't work:
```
@agent2.register_for_llm(api_style="function")
def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c=3.14) -> str:
return a + str(b * c)
```
"""

def _decorator(func: F) -> F:
Expand Down Expand Up @@ -1762,7 +1775,13 @@ def _decorator(func: F) -> F:
if self.llm_config is None:
raise RuntimeError("LLM config must be setup before registering a function for LLM.")

self.update_tool_signature(f, is_remove=False)
if api_style == "function":
f = f["function"]
self.update_function_signature(f, is_remove=False)
elif api_style == "tool":
self.update_tool_signature(f, is_remove=False)
else:
raise ValueError(f"Unsupported API style: {api_style}")

return func

Expand Down
71 changes: 71 additions & 0 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,77 @@ async def exec_sh(script: Annotated[str, "Valid shell script to execute."]) -> s
assert agent3.llm_config["tools"] == expected3


def test_register_for_llm_api_style_function():
    """Check that `register_for_llm(api_style="function")` registers the schema
    under ``llm_config["functions"]`` (the legacy function-call API list)
    instead of ``llm_config["tools"]``, and that ``name``/``description``
    overrides propagate correctly through stacked decorators.
    """
    with pytest.MonkeyPatch.context() as mp:
        # A fake key is enough: nothing here makes a real API call, the test
        # only exercises local schema registration.
        mp.setenv("OPENAI_API_KEY", "mock")
        agent3 = ConversableAgent(name="agent3", llm_config={"config_list": []})
        agent2 = ConversableAgent(name="agent2", llm_config={"config_list": []})
        agent1 = ConversableAgent(name="agent1", llm_config={"config_list": []})

        # Decorators apply innermost-first: agent1, then agent2, then agent3.
        @agent3.register_for_llm(api_style="function")
        @agent2.register_for_llm(name="python", api_style="function")
        @agent1.register_for_llm(
            description="run cell in ipython and return the execution result.", api_style="function"
        )
        def exec_python(cell: Annotated[str, "Valid Python cell to execute."]) -> str:
            pass

        # Expected legacy "functions" schema (no "type"/"function" wrapper,
        # unlike the "tools" format).
        expected1 = [
            {
                "description": "run cell in ipython and return the execution result.",
                "name": "exec_python",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "cell": {
                            "type": "string",
                            "description": "Valid Python cell to execute.",
                        }
                    },
                    "required": ["cell"],
                },
            }
        ]
        expected2 = copy.deepcopy(expected1)
        expected2[0]["name"] = "python"
        # NOTE(review): agent3 is expected to pick up the name "python" set by
        # agent2's (inner) decorator — presumably register_for_llm records the
        # name on the wrapped function so outer decorators reuse it; verify
        # against ConversableAgent.register_for_llm.
        expected3 = expected2

        assert agent1.llm_config["functions"] == expected1
        assert agent2.llm_config["functions"] == expected2
        assert agent3.llm_config["functions"] == expected3

        # Second registration on the same agents: an async function, name and
        # description overridden on the innermost decorator only.
        @agent3.register_for_llm(api_style="function")
        @agent2.register_for_llm(api_style="function")
        @agent1.register_for_llm(
            name="sh", description="run a shell script and return the execution result.", api_style="function"
        )
        async def exec_sh(script: Annotated[str, "Valid shell script to execute."]) -> str:
            pass

        # `expected3 = expected2` above aliased the same list, but the `+`
        # below rebinds each name to a fresh list, so no shared mutation.
        expected1 = expected1 + [
            {
                "name": "sh",
                "description": "run a shell script and return the execution result.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "script": {
                            "type": "string",
                            "description": "Valid shell script to execute.",
                        }
                    },
                    "required": ["script"],
                },
            }
        ]
        expected2 = expected2 + [expected1[1]]
        expected3 = expected3 + [expected1[1]]

        assert agent1.llm_config["functions"] == expected1
        assert agent2.llm_config["functions"] == expected2
        assert agent3.llm_config["functions"] == expected3


def test_register_for_llm_without_description():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
Expand Down
56 changes: 56 additions & 0 deletions test/agentchat/test_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,62 @@ def test_eval_math_responses():
print(eval_math_responses(**arguments))


@pytest.mark.skipif(skip_openai or not TOOL_ENABLED, reason="openai>=1.1.0 not installed or requested to skip")
def test_eval_math_responses_api_style_function():
    """End-to-end check of the legacy `functions` (function-call) API style
    against Azure OpenAI API versions that predate the `tools` API.
    """
    # Restrict to Azure configs whose API versions still require the legacy
    # function-call style (2023-10-01-preview and earlier).
    legacy_api_versions = [
        "2023-10-01-preview",
        "2023-09-01-preview",
        "2023-08-01-preview",
        "2023-07-01-preview",
    ]
    config_list = autogen.config_list_from_models(
        KEY_LOC,
        model_list=["gpt-4-0613", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k"],
        filter_dict={"api_type": ["azure"], "api_version": legacy_api_versions},
    )
    # JSON schema for the single function the model is allowed to call.
    parameters_schema = {
        "type": "object",
        "properties": {
            "responses": {
                "type": "array",
                "items": {"type": "string"},
                "description": "The responses in a list",
            },
            "solution": {
                "type": "string",
                "description": "The canonical solution",
            },
        },
        "required": ["responses"],
    }
    function_schemas = [
        {
            "name": "eval_math_responses",
            "description": "Select a response for a math problem using voting, and check if the response is correct if the solution is provided",
            "parameters": parameters_schema,
        },
    ]
    client = autogen.OpenAIWrapper(config_list=config_list)
    completion = client.create(
        messages=[
            {
                "role": "user",
                "content": 'evaluate the math responses ["1", "5/2", "5/2"] against the true answer \\frac{5}{2}',
            },
        ],
        functions=function_schemas,
    )
    print(completion)
    choices = client.extract_text_or_completion_object(completion)
    print(choices[0])
    call = choices[0].function_call
    fn_name = call.name
    fn_args = json.loads(call.arguments)
    assert fn_name == "eval_math_responses"
    print(fn_args["responses"])
    # Wrap each candidate answer in \boxed{...} so it matches the format
    # eval_math_responses expects.
    fn_args["responses"] = [f"\\boxed{{{x}}}" for x in fn_args["responses"]]
    print(fn_args["responses"])
    fn_args["solution"] = f"\\boxed{{{fn_args['solution']}}}"
    print(eval_math_responses(**fn_args))


@pytest.mark.skipif(
skip_openai or not TOOL_ENABLED or not sys.version.startswith("3.10"),
reason="do not run if openai is <1.1.0 or py!=3.10 or requested to skip",
Expand Down
6 changes: 5 additions & 1 deletion website/docs/Use-Cases/agent_chat.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol)
else:
raise ValueError(f"Unknown currencies {base_currency}, {quote_currency}")


# NOTE: for Azure OpenAI, please use API version 2023-12-01-preview or later as
# support for earlier versions will be deprecated.
# For API versions 2023-10-01-preview or earlier, you may
# need to set `api_style="function"` in the decorator if the default value does not work:
# `register_for_llm(description=..., api_style="function")`.
@user_proxy.register_for_execution()
@chatbot.register_for_llm(description="Currency exchange calculator.")
def currency_calculator(
Expand Down

0 comments on commit dec108e

Please sign in to comment.