Skip to content

Commit

Permalink
allow tools to return any (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Dec 4, 2024
1 parent 3f0234f commit 1a72ca5
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 39 deletions.
2 changes: 1 addition & 1 deletion docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ agent.run_sync('hello', model=FunctionModel(print_schema))

_(This example is complete, it can be run "as is")_

The return type of tool can be any valid JSON object ([`JsonData`][pydantic_ai.tools.JsonData]) as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.
The return type of tool can be anything which Pydantic can serialize to JSON as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.

If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the tool is simplified to be just that object. (TODO example)

Expand Down
23 changes: 7 additions & 16 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from __future__ import annotations as _annotations

import json
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Annotated, Any, Literal, Union
from typing import Annotated, Any, Literal, Union

import pydantic
import pydantic_core
from pydantic import TypeAdapter
from typing_extensions import TypeAlias, TypeAliasType

from . import _pydantic
from ._utils import now_utc as _now_utc
Expand Down Expand Up @@ -44,13 +42,7 @@ class UserPrompt:
"""Message type identifier, this type is available on all message as a discriminator."""


JsonData: TypeAlias = 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]'
if not TYPE_CHECKING:
# work around for https://github.com/pydantic/pydantic/issues/10873
# this is need for pydantic to work both `json_ta` and `MessagesTypeAdapter` at the bottom of this file
JsonData = TypeAliasType('JsonData', 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]')

json_ta: TypeAdapter[JsonData] = TypeAdapter(JsonData)
tool_return_ta: TypeAdapter[Any] = TypeAdapter(Any)


@dataclass
Expand All @@ -59,7 +51,7 @@ class ToolReturn:

tool_name: str
"""The name of the "tool" was called."""
content: JsonData
content: Any
"""The return value."""
tool_id: str | None = None
"""Optional tool identifier, this is used by some models including OpenAI."""
Expand All @@ -72,15 +64,14 @@ def model_response_str(self) -> str:
if isinstance(self.content, str):
return self.content
else:
content = json_ta.validate_python(self.content)
return json_ta.dump_json(content).decode()
return tool_return_ta.dump_json(self.content).decode()

def model_response_object(self) -> dict[str, JsonData]:
def model_response_object(self) -> dict[str, Any]:
# gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
if isinstance(self.content, dict):
return json_ta.validate_python(self.content) # pyright: ignore[reportReturnType]
return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
else:
return {'return_value': json_ta.validate_python(self.content)}
return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}


@dataclass
Expand Down
19 changes: 6 additions & 13 deletions pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations as _annotations

import inspect
from collections.abc import Awaitable, Mapping, Sequence
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast

from pydantic import ValidationError
from pydantic_core import SchemaValidator
from typing_extensions import Concatenate, ParamSpec, TypeAlias, final
from typing_extensions import Concatenate, ParamSpec, final

from . import _pydantic, _utils, messages
from .exceptions import ModelRetry, UnexpectedModelBehavior
Expand All @@ -23,12 +23,10 @@
'RunContext',
'ResultValidatorFunc',
'SystemPromptFunc',
'ToolReturnValue',
'ToolFuncContext',
'ToolFuncPlain',
'ToolFuncEither',
'ToolParams',
'JsonData',
'Tool',
)

Expand Down Expand Up @@ -75,17 +73,12 @@ class RunContext(Generic[AgentDeps]):
Usage `ResultValidator[AgentDeps, ResultData]`.
"""

JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]'
"""Type representing any JSON data."""

ToolReturnValue = Union[JsonData, Awaitable[JsonData]]
"""Return value of a tool function."""
ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue]
ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any]
"""A tool function that takes `RunContext` as the first argument.
Usage `ToolContextFunc[AgentDeps, ToolParams]`.
"""
ToolFuncPlain = Callable[ToolParams, ToolReturnValue]
ToolFuncPlain = Callable[ToolParams, Any]
"""A tool function that does not take `RunContext` as the first argument.
Usage `ToolPlainFunc[ToolParams]`.
Expand Down Expand Up @@ -146,8 +139,8 @@ async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
function: The Python function to call as the tool.
takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument.
max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`.
name: Name of the tool, inferred from the function if left blank.
description: Description of the tool, inferred from the function if left blank.
name: Name of the tool, inferred from the function if `None`.
description: Description of the tool, inferred from the function if `None`.
"""
f = _pydantic.function_schema(function, takes_ctx)
self.function = function
Expand Down
56 changes: 53 additions & 3 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from inline_snapshot import snapshot
from pydantic import BaseModel, Field
from pydantic_core import PydanticSerializationError

from pydantic_ai import Agent, RunContext, Tool, UserError
from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse
Expand Down Expand Up @@ -209,12 +210,14 @@ def test_docstring_google_no_body(set_event_loop: None):
)


class Foo(BaseModel):
x: int
y: str


def test_takes_just_model(set_event_loop: None):
agent = Agent()

class Foo(BaseModel):
x: int
y: str

@agent.tool_plain
def takes_just_model(model: Foo) -> str:
Expand Down Expand Up @@ -343,3 +346,50 @@ def plain_tool(x: int) -> int:
def test_init_plain_tool_invalid():
with pytest.raises(UserError, match='RunContext annotations can only be used with tools that take context'):
Tool(ctx_tool, False)


def test_return_pydantic_model(set_event_loop: None):
agent = Agent('test')

@agent.tool_plain
def return_pydantic_model(x: int) -> Foo:
return Foo(x=x, y='a')

result = agent.run_sync('')
assert result.data == snapshot('{"return_pydantic_model":{"x":0,"y":"a"}}')


def test_return_bytes(set_event_loop: None):
agent = Agent('test')

@agent.tool_plain
def return_pydantic_model() -> bytes:
return '🐈 Hello'.encode()

result = agent.run_sync('')
assert result.data == snapshot('{"return_pydantic_model":"🐈 Hello"}')


def test_return_bytes_invalid(set_event_loop: None):
agent = Agent('test')

@agent.tool_plain
def return_pydantic_model() -> bytes:
return b'\00 \x81'

with pytest.raises(PydanticSerializationError, match='invalid utf-8 sequence of 1 bytes from index 2'):
agent.run_sync('')


def test_return_unknown(set_event_loop: None):
agent = Agent('test')

class Foobar:
pass

@agent.tool_plain
def return_pydantic_model() -> Foobar:
return Foobar()

with pytest.raises(PydanticSerializationError, match='Unable to serialize unknown type:'):
agent.run_sync('')
12 changes: 6 additions & 6 deletions tests/typed_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,15 @@ def ok_tool_plain(x: str) -> dict[str, str]:


@typed_agent.tool_plain
def ok_json_list(x: str) -> list[Union[str, int]]:
async def ok_json_list(x: str) -> list[Union[str, int]]:
return [x, 1]


@typed_agent.tool
async def ok_ctx(ctx: RunContext[MyDeps], x: str) -> list[int | str]:
return [ctx.deps.foo, ctx.deps.bar, x]


@typed_agent.tool
async def bad_tool1(ctx: RunContext[MyDeps], x: str) -> str:
total = ctx.deps.foo + ctx.deps.spam # type: ignore[attr-defined]
Expand All @@ -76,11 +81,6 @@ async def bad_tool2(ctx: RunContext[int], x: str) -> str:
return f'{x} {ctx.deps}'


@typed_agent.tool_plain # type: ignore[arg-type]
async def bad_tool_return(x: int) -> list[MyDeps]:
return [MyDeps(1, x)]


with expect_error(ValueError):

@typed_agent.tool # type: ignore[arg-type]
Expand Down

0 comments on commit 1a72ca5

Please sign in to comment.