diff --git a/docs/agents.md b/docs/agents.md index c75869a5..678ea59c 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -489,7 +489,7 @@ agent.run_sync('hello', model=FunctionModel(print_schema)) _(This example is complete, it can be run "as is")_ -The return type of tool can be any valid JSON object ([`JsonData`][pydantic_ai.tools.JsonData]) as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON. +The return type of tool can be anything which Pydantic can serialize to JSON as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON. If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the tool is simplified to be just that object. (TODO example) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 763f2716..30b2eef4 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -1,15 +1,13 @@ from __future__ import annotations as _annotations import json -from collections.abc import Mapping, Sequence from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Annotated, Any, Literal, Union +from typing import Annotated, Any, Literal, Union import pydantic import pydantic_core from pydantic import TypeAdapter -from typing_extensions import TypeAlias, TypeAliasType from . import _pydantic from ._utils import now_utc as _now_utc @@ -44,13 +42,7 @@ class UserPrompt: """Message type identifier, this type is available on all message as a discriminator.""" -JsonData: TypeAlias = 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]' -if not TYPE_CHECKING: - # work around for https://github.com/pydantic/pydantic/issues/10873 - # this is need for pydantic to work both `json_ta` and `MessagesTypeAdapter` at the bottom of this file - JsonData = TypeAliasType('JsonData', 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]') - -json_ta: TypeAdapter[JsonData] = TypeAdapter(JsonData) +tool_return_ta: TypeAdapter[Any] = TypeAdapter(Any) @dataclass @@ -59,7 +51,7 @@ class ToolReturn: tool_name: str """The name of the "tool" was called.""" - content: JsonData + content: Any """The return value.""" tool_id: str | None = None """Optional tool identifier, this is used by some models including OpenAI.""" @@ -72,15 +64,14 @@ def model_response_str(self) -> str: if isinstance(self.content, str): return self.content else: - content = json_ta.validate_python(self.content) - return json_ta.dump_json(content).decode() + return tool_return_ta.dump_json(self.content).decode() - def model_response_object(self) -> dict[str, JsonData]: + def model_response_object(self) -> dict[str, Any]: # gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict if isinstance(self.content, dict): - return json_ta.validate_python(self.content) # pyright: ignore[reportReturnType] + return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType] else: - return {'return_value': json_ta.validate_python(self.content)} + return {'return_value': tool_return_ta.dump_python(self.content, mode='json')} @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 61995c3a..8021003e 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -1,13 +1,13 @@ from __future__ import annotations as _annotations import inspect -from collections.abc import Awaitable, Mapping, Sequence +from collections.abc import Awaitable from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast from pydantic import ValidationError from pydantic_core import SchemaValidator -from typing_extensions import Concatenate, ParamSpec, TypeAlias, final +from typing_extensions import Concatenate, ParamSpec, final from . import _pydantic, _utils, messages from .exceptions import ModelRetry, UnexpectedModelBehavior @@ -23,12 +23,10 @@ 'RunContext', 'ResultValidatorFunc', 'SystemPromptFunc', - 'ToolReturnValue', 'ToolFuncContext', 'ToolFuncPlain', 'ToolFuncEither', 'ToolParams', - 'JsonData', 'Tool', ) @@ -75,17 +73,12 @@ class RunContext(Generic[AgentDeps]): Usage `ResultValidator[AgentDeps, ResultData]`. """ -JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]' -"""Type representing any JSON data.""" - -ToolReturnValue = Union[JsonData, Awaitable[JsonData]] -"""Return value of a tool function.""" -ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue] +ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any] """A tool function that takes `RunContext` as the first argument. Usage `ToolContextFunc[AgentDeps, ToolParams]`. """ -ToolFuncPlain = Callable[ToolParams, ToolReturnValue] +ToolFuncPlain = Callable[ToolParams, Any] """A tool function that does not take `RunContext` as the first argument. Usage `ToolPlainFunc[ToolParams]`. @@ -146,8 +139,8 @@ async def my_tool(ctx: RunContext[int], x: int, y: int) -> str: function: The Python function to call as the tool. takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument. max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`. - name: Name of the tool, inferred from the function if left blank. - description: Description of the tool, inferred from the function if left blank. + name: Name of the tool, inferred from the function if `None`. + description: Description of the tool, inferred from the function if `None`. """ f = _pydantic.function_schema(function, takes_ctx) self.function = function diff --git a/tests/test_tools.py b/tests/test_tools.py index 640b8b1e..39905ae1 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -4,6 +4,7 @@ import pytest from inline_snapshot import snapshot from pydantic import BaseModel, Field +from pydantic_core import PydanticSerializationError from pydantic_ai import Agent, RunContext, Tool, UserError from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse @@ -209,12 +210,14 @@ def test_docstring_google_no_body(set_event_loop: None): ) +class Foo(BaseModel): + x: int + y: str + + def test_takes_just_model(set_event_loop: None): agent = Agent() - class Foo(BaseModel): - x: int - y: str @agent.tool_plain def takes_just_model(model: Foo) -> str: @@ -343,3 +346,50 @@ def plain_tool(x: int) -> int: def test_init_plain_tool_invalid(): with pytest.raises(UserError, match='RunContext annotations can only be used with tools that take context'): Tool(ctx_tool, False) + + +def test_return_pydantic_model(set_event_loop: None): + agent = Agent('test') + + @agent.tool_plain + def return_pydantic_model(x: int) -> Foo: + return Foo(x=x, y='a') + + result = agent.run_sync('') + assert result.data == snapshot('{"return_pydantic_model":{"x":0,"y":"a"}}') + + +def test_return_bytes(set_event_loop: None): + agent = Agent('test') + + @agent.tool_plain + def return_pydantic_model() -> bytes: + return '🐈 Hello'.encode() + + result = agent.run_sync('') + assert result.data == snapshot('{"return_pydantic_model":"🐈 Hello"}') + + +def test_return_bytes_invalid(set_event_loop: None): + agent = Agent('test') + + @agent.tool_plain + def return_pydantic_model() -> bytes: + return b'\00 \x81' + + with pytest.raises(PydanticSerializationError, match='invalid utf-8 sequence of 1 bytes from index 2'): + agent.run_sync('') + + +def test_return_unknown(set_event_loop: None): + agent = Agent('test') + + class Foobar: + pass + + @agent.tool_plain + def return_pydantic_model() -> Foobar: + return Foobar() + + with pytest.raises(PydanticSerializationError, match='Unable to serialize unknown type:'): + agent.run_sync('') diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 19bed49b..1daa92c3 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -61,10 +61,15 @@ def ok_tool_plain(x: str) -> dict[str, str]: @typed_agent.tool_plain -def ok_json_list(x: str) -> list[Union[str, int]]: +async def ok_json_list(x: str) -> list[Union[str, int]]: return [x, 1] +@typed_agent.tool +async def ok_ctx(ctx: RunContext[MyDeps], x: str) -> list[int | str]: + return [ctx.deps.foo, ctx.deps.bar, x] + + @typed_agent.tool async def bad_tool1(ctx: RunContext[MyDeps], x: str) -> str: total = ctx.deps.foo + ctx.deps.spam # type: ignore[attr-defined] @@ -76,11 +81,6 @@ async def bad_tool2(ctx: RunContext[int], x: str) -> str: return f'{x} {ctx.deps}' -@typed_agent.tool_plain # type: ignore[arg-type] -async def bad_tool_return(x: int) -> list[MyDeps]: - return [MyDeps(1, x)] - - with expect_error(ValueError): @typed_agent.tool # type: ignore[arg-type]