diff --git a/.gitignore b/.gitignore
index 593a5e22..f7cc9b3d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ __pycache__
 /scratch/
 /.coverage
 env*/
+/TODO.md
diff --git a/pydantic_ai/_system_prompt.py b/pydantic_ai/_system_prompt.py
index 4cd8e64e..5910fc36 100644
--- a/pydantic_ai/_system_prompt.py
+++ b/pydantic_ai/_system_prompt.py
@@ -3,16 +3,16 @@
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass
-from typing import Callable, Generic, Union
+from typing import Any, Callable, Generic, Union, cast
 
-from . import _utils, retrievers as _r
-from .retrievers import AgentDeps
+from . import _utils
+from .retrievers import AgentDeps, CallContext
 
-# This is basically a function that may or maybe not take `CallInfo` as an argument, and may or may not be async.
-# Usage `SystemPromptFunc[AgentDependencies]`
+# A function that may or may not take `CallContext` as an argument, and may or may not be async.
+# Usage `SystemPromptFunc[AgentDeps]`
 SystemPromptFunc = Union[
-    Callable[[_r.CallContext[AgentDeps]], str],
-    Callable[[_r.CallContext[AgentDeps]], Awaitable[str]],
+    Callable[[CallContext[AgentDeps]], str],
+    Callable[[CallContext[AgentDeps]], Awaitable[str]],
     Callable[[], str],
     Callable[[], Awaitable[str]],
 ]
@@ -21,23 +21,22 @@
 @dataclass
 class SystemPromptRunner(Generic[AgentDeps]):
     function: SystemPromptFunc[AgentDeps]
-    takes_ctx: bool = False
-    is_async: bool = False
+    _takes_ctx: bool = False
+    _is_async: bool = False
 
     def __post_init__(self):
-        self.takes_ctx = len(inspect.signature(self.function).parameters) > 0
-        self.is_async = inspect.iscoroutinefunction(self.function)
+        self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
+        self._is_async = inspect.iscoroutinefunction(self.function)
 
     async def run(self, deps: AgentDeps) -> str:
-        if self.takes_ctx:
-            args = (_r.CallContext(deps, 0),)
+        if self._takes_ctx:
+            args = (CallContext(deps, 0),)
         else:
            args = ()
-        if self.is_async:
-            return await self.function(*args)  # pyright: ignore[reportGeneralTypeIssues,reportUnknownVariableType]
+        if self._is_async:
+            function = cast(Callable[[Any], Awaitable[str]], self.function)
+            return await function(*args)
         else:
-            return await _utils.run_in_executor(
-                self.function,  # pyright: ignore[reportArgumentType,reportReturnType]
-                *args,
-            )
+            function = cast(Callable[[Any], str], self.function)
+            return await _utils.run_in_executor(function, *args)
diff --git a/pydantic_ai/_utils.py b/pydantic_ai/_utils.py
index 3a2b5402..e6a05311 100644
--- a/pydantic_ai/_utils.py
+++ b/pydantic_ai/_utils.py
@@ -91,7 +91,7 @@ class Some(Generic[_T]):
 
 
 class Either(Generic[_Left, _Right]):
-    """Two member Union that records which member was set.
+    """Two member Union that records which member was set; this is analogous to Rust enums with two variants.
 
     Usage:
 
@@ -109,14 +109,19 @@
     def __init__(self, *, left: _Left) -> None: ...
 
     @overload
     def __init__(self, *, right: _Right) -> None: ...
-    def __init__(self, *, left: _Left | None = None, right: _Right | None = None) -> None:
-        if (left is not None and right is not None) or (left is None and right is None):
-            raise TypeError('Either must have exactly one value')
-        self._left = left
-        self._right = right
+    def __init__(self, **kwargs: Any) -> None:
+        keys = set(kwargs.keys())
+        if keys == {'left'}:
+            self._left: Option[_Left] = Some(kwargs['left'])
+            self._right: _Right | None = None
+        elif keys == {'right'}:
+            self._left = None
+            self._right = kwargs['right']
+        else:
+            raise TypeError('Either must receive exactly one value - `left` or `right`')
 
     @property
-    def left(self) -> _Left | None:
+    def left(self) -> Option[_Left]:
         return self._left
 
     @property
@@ -129,4 +134,4 @@ def is_left(self) -> bool:
         return self._left is not None
 
     def whichever(self) -> _Left | _Right:
-        return self._left or self.right
+        return self._left.value if self._left is not None else self.right
diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py
index 0a9d55b6..6d0af78b 100644
--- a/pydantic_ai/agent.py
+++ b/pydantic_ai/agent.py
@@ -2,11 +2,13 @@
 import asyncio
 from collections.abc import Awaitable, Sequence
+from dataclasses import dataclass
 from typing import Any, Callable, Generic, Literal, cast, overload
 
 from typing_extensions import assert_never
 
 from . import _system_prompt, _utils, messages as _messages, models as _models, result as _result, retrievers as _r
+from .models import Model
 from .result import ResultData
 from .retrievers import AgentDeps
 
@@ -14,20 +16,22 @@
 KnownModelName = Literal['openai:gpt-4o', 'openai:gpt-4-turbo', 'openai:gpt-4', 'openai:gpt-3.5-turbo']
 
 
+@dataclass(init=False)
 class Agent(Generic[AgentDeps, ResultData]):
     """Main class for creating "agents" - a way to have a specific type of "conversation" with an LLM."""
 
     # slots mostly for my sanity — knowing what attributes are available
-    __slots__ = (
-        '_model',
-        '_result_tool',
-        '_allow_text_result',
-        '_system_prompts',
-        '_retrievers',
-        '_default_retries',
-        '_system_prompt_functions',
-        '_default_deps',
-    )
+    _model: Model | None
+    _result_tool: _result.ResultSchema[ResultData] | None
+    _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
+    _allow_text_result: bool
+    _system_prompts: tuple[str, ...]
+    _retrievers: dict[str, _r.Retriever[AgentDeps, Any]]
+    _default_retries: int
+    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]]
+    _default_deps: AgentDeps
+    _max_result_retries: int
+    _current_result_retry: int
 
     def __init__(
         self,
@@ -45,10 +49,7 @@ def __init__(
         self._model = _models.infer_model(model) if model is not None else None
 
         self._result_tool = _result.ResultSchema[result_type].build(
-            result_type,
-            result_tool_name,
-            result_tool_description,
-            result_retries if result_retries is not None else retries,
+            result_type, result_tool_name, result_tool_description
         )
         # if the result tool is None, or its schema allows `str`, we allow plain text results
         self._allow_text_result = self._result_tool is None or self._result_tool.allow_text_result
@@ -57,7 +58,10 @@ def __init__(
         self._retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = {}
         self._default_deps = cast(AgentDeps, None if deps == () else deps)
         self._default_retries = retries
-        self._system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = []
+        self._system_prompt_functions = []
+        self._max_result_retries = result_retries if result_retries is not None else retries
+        self._current_result_retry = 0
+        self._result_validators = []
 
     async def run(
         self,
@@ -142,6 +146,13 @@ def system_prompt(
         self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
         return func
 
+    def result_validator(
+        self, func: _result.ResultValidatorFunc[AgentDeps, ResultData]
+    ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
+        """Decorator to register a result validator function."""
+        self._result_validators.append(_result.ResultValidator(func))
+        return func
+
     @overload
     def retriever_context(self, func: _r.RetrieverContextFunc[AgentDeps, _r.P], /) -> _r.Retriever[AgentDeps, _r.P]: ...
 
@@ -225,8 +236,12 @@ async def _handle_model_response(
                 if call is not None:
                     either = self._result_tool.validate(call)
                     if result_data := either.left:
-                        return _utils.Some(result_data)
+                        either = await self._validate_result(result_data.value, deps, call)
+
+                    if result_data := either.left:
+                        return _utils.Some(result_data.value)
                     else:
+                        self._incr_result_retry()
                         messages.append(either.right)
                         return None
@@ -242,6 +257,22 @@
         else:
             assert_never(llm_message)
 
+    async def _validate_result(
+        self, result: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall
+    ) -> _utils.Either[ResultData, _messages.ToolRetry]:
+        for validator in self._result_validators:
+            either = await validator.validate(result, deps, self._current_result_retry, tool_call)
+            if either.left:
+                result = either.left.value
+            else:
+                return either
+        return _utils.Either(left=result)
+
+    def _incr_result_retry(self) -> None:
+        self._current_result_retry += 1
+        if self._current_result_retry > self._max_result_retries:
+            raise RuntimeError(f'Exceeded maximum retries ({self._max_result_retries}) for result validation')
+
     async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]:
         """Build the initial messages for the conversation."""
         messages: list[_messages.Message] = [_messages.SystemPrompt(p) for p in self._system_prompts]
diff --git a/pydantic_ai/models/test.py b/pydantic_ai/models/test.py
index f81dfc1d..8487fe22 100644
--- a/pydantic_ai/models/test.py
+++ b/pydantic_ai/models/test.py
@@ -12,7 +12,7 @@
 from typing import Any, Literal
 
 from .. import _utils
-from ..messages import LLMMessage, LLMResponse, LLMToolCalls, Message, ToolCall
+from ..messages import LLMMessage, LLMResponse, LLMToolCalls, Message, ToolCall, ToolRetry, ToolReturn
 from . import AbstractToolDefinition, AgentModel, Model
 
@@ -38,23 +38,23 @@ def agent_model(
         self, allow_text_result: bool, tools: list[AbstractToolDefinition], result_tool_name: str | None
     ) -> AgentModel:
         if self.call_retrievers == 'all':
-            retriever_calls = [(r.name, gen_retriever_args(r)) for r in tools if r.name != 'response']
+            retriever_calls = [(r.name, r) for r in tools if r.name != 'response']
         else:
             lookup = {r.name: r for r in tools}
-            retriever_calls = [(name, gen_retriever_args(lookup[name])) for name in self.call_retrievers]
+            retriever_calls = [(name, lookup[name]) for name in self.call_retrievers]
 
         if self.custom_result_text is not None:
             if not allow_text_result:
                 raise ValueError('Plain response not allowed, but `custom_result_text` is set.')
-            result: _utils.Either[str, str] = _utils.Either(left=self.custom_result_text)
+            result: _utils.Either[str | None, AbstractToolDefinition] = _utils.Either(left=self.custom_result_text)
         elif self.custom_result_args is not None:
             assert result_tool_name is not None, 'No result tool name provided, but `custom_result_args` is set.'
             result = _utils.Either(right=self.custom_result_args)
         elif result_tool_name is not None:
             response_def = next(r for r in tools if r.name == result_tool_name)
-            result = _utils.Either(right=gen_retriever_args(response_def))
+            result = _utils.Either(right=response_def)
         else:
-            result = _utils.Either(left='Final response')
+            result = _utils.Either(left=None)
         return TestAgentModel(retriever_calls, result, result_tool_name)
 
@@ -63,31 +63,54 @@
 class TestAgentModel(AgentModel):
     # NOTE: Avoid test discovery by pytest.
     __test__ = False
 
-    retriever_calls: list[tuple[str, str]]
+    retriever_calls: list[tuple[str, AbstractToolDefinition]]
     # left means the text is plain text, right means it's a function call
-    result: _utils.Either[str, str]
+    result: _utils.Either[str | None, AbstractToolDefinition]
     result_tool_name: str | None
     step: int = 0
+    last_message_count: int = 0
 
     async def request(self, messages: list[Message]) -> LLMMessage:
         if self.step == 0:
+            calls = [
+                ToolCall(tool_name=name, arguments=self.gen_retriever_args(args)) for name, args in self.retriever_calls
+            ]
             self.step += 1
-            return LLMToolCalls(calls=[ToolCall(tool_name=name, arguments=args) for name, args in self.retriever_calls])
-        elif self.step == 1:
+            self.last_message_count = len(messages)
+            return LLMToolCalls(calls=calls)
+
+        new_messages = messages[self.last_message_count :]
+        self.last_message_count = len(messages)
+        new_retry_names = {m.tool_name for m in new_messages if isinstance(m, ToolRetry)}
+        if new_retry_names:
+            calls = [
+                ToolCall(tool_name=name, arguments=self.gen_retriever_args(args))
+                for name, args in self.retriever_calls
+                if name in new_retry_names
+            ]
             self.step += 1
+            return LLMToolCalls(calls=calls)
+        else:
             if response_text := self.result.left:
-                return LLMResponse(content=response_text)
+                self.step += 1
+                if response_text.value is None:
+                    # build up details of retriever responses
+                    output: dict[str, str] = {}
+                    for message in messages:
+                        if isinstance(message, ToolReturn):
+                            output[message.tool_name] = message.content
+                    return LLMResponse(content=json.dumps(output))
+                else:
+                    return LLMResponse(content=response_text.value)
             else:
                 assert self.result_tool_name is not None, 'No result tool name provided'
-                response_args = self.result.right
+                response_args = self.gen_retriever_args(self.result.right)
+                self.step += 1
                 return LLMToolCalls(calls=[ToolCall(tool_name=self.result_tool_name, arguments=response_args)])
-        else:
-            raise ValueError('Invalid step')
-
-def gen_retriever_args(tool_def: AbstractToolDefinition) -> str:
-    """Generate arguments for a retriever."""
-    return _JsonSchemaTestData(tool_def.json_schema).generate_json()
+
+    def gen_retriever_args(self, tool_def: AbstractToolDefinition) -> str:
+        """Generate arguments for a retriever."""
+        return _JsonSchemaTestData(tool_def.json_schema, self.step).generate_json()
 
 
 _chars = string.ascii_letters + string.digits + string.punctuation
@@ -100,10 +123,10 @@ class _JsonSchemaTestData:
     This tries to generate the minimal viable data for the schema.
     """
 
-    def __init__(self, schema: _utils.ObjectJsonSchema):
+    def __init__(self, schema: _utils.ObjectJsonSchema, seed: int = 0):
         self.schema = schema
         self.defs = schema.get('$defs', {})
-        self.seed = 0
+        self.seed = seed
 
     def generate(self) -> Any:
         """Generate data for the JSON schema."""
diff --git a/pydantic_ai/result.py b/pydantic_ai/result.py
index 81165970..86503e61 100644
--- a/pydantic_ai/result.py
+++ b/pydantic_ai/result.py
@@ -1,12 +1,17 @@
 from __future__ import annotations as _annotations
 
+import inspect
+from collections.abc import Awaitable
 from dataclasses import dataclass
-from typing import Any, Generic, TypeVar
+from typing import Any, Callable, Generic, TypeVar, Union, cast
 
 from pydantic import TypeAdapter, ValidationError
 from typing_extensions import Self, TypedDict
 
 from . import _utils, messages
+from .retrievers import AgentDeps, CallContext, Retry
+
+ResultData = TypeVar('ResultData')
 
 
 @dataclass
@@ -16,9 +21,6 @@ class Cost:
     total_cost: int
 
 
-ResultData = TypeVar('ResultData')
-
-
 @dataclass
 class RunResult(Generic[ResultData]):
     """Result of a run."""
@@ -45,11 +47,9 @@ class ResultSchema(Generic[ResultData]):
     json_schema: _utils.ObjectJsonSchema
     allow_text_result: bool
     outer_typed_dict: bool
-    max_retries: int
-    _current_retry: int = 0
 
     @classmethod
-    def build(cls, response_type: type[ResultData], name: str, description: str, retries: int) -> Self | None:
+    def build(cls, response_type: type[ResultData], name: str, description: str) -> Self | None:
         """Build a ResultSchema dataclass from a response type."""
         if response_type is str:
             return None
@@ -70,25 +70,82 @@ def build(cls, response_type: type[ResultData], name: str, description: str, ret
             json_schema=_utils.check_object_json_schema(type_adapter.json_schema()),
             allow_text_result=_utils.allow_plain_str(response_type),
             outer_typed_dict=outer_typed_dict,
-            max_retries=retries,
         )
 
-    def validate(self, message: messages.ToolCall) -> _utils.Either[ResultData, messages.ToolRetry]:
-        """Validate a result message."""
+    def validate(self, tool_call: messages.ToolCall) -> _utils.Either[ResultData, messages.ToolRetry]:
+        """Validate a result message.
+
+        Returns:
+            Either the validated result data (left) or a retry message (right).
+        """
         try:
-            result = self.type_adapter.validate_json(message.arguments)
+            result = self.type_adapter.validate_json(tool_call.arguments)
         except ValidationError as e:
-            self._current_retry += 1
-            if self._current_retry > self.max_retries:
-                raise
-            else:
-                m = messages.ToolRetry(
-                    tool_name=message.tool_name,
-                    content=e.errors(),
-                    tool_id=message.tool_id,
-                )
-                return _utils.Either(right=m)
+            m = messages.ToolRetry(
+                tool_name=tool_call.tool_name,
+                content=e.errors(include_url=False),
+                tool_id=tool_call.tool_id,
+            )
+            return _utils.Either(right=m)
         else:
             if self.outer_typed_dict:
                 result = result['response']
             return _utils.Either(left=result)
+
+
+# A function that always takes `ResultData` and returns `ResultData`,
+# but may or may not take `CallContext` as a first argument, and may or may not be async.
+# Usage `ResultValidatorFunc[AgentDeps, ResultData]`
+ResultValidatorFunc = Union[
+    Callable[[CallContext[AgentDeps], ResultData], ResultData],
+    Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]],
+    Callable[[ResultData], ResultData],
+    Callable[[ResultData], Awaitable[ResultData]],
+]
+
+
+@dataclass
+class ResultValidator(Generic[AgentDeps, ResultData]):
+    function: ResultValidatorFunc[AgentDeps, ResultData]
+    _takes_ctx: bool = False
+    _is_async: bool = False
+
+    def __post_init__(self):
+        self._takes_ctx = len(inspect.signature(self.function).parameters) > 1
+        self._is_async = inspect.iscoroutinefunction(self.function)
+
+    async def validate(
+        self, result: ResultData, deps: AgentDeps, retry: int, tool_call: messages.ToolCall
+    ) -> _utils.Either[ResultData, messages.ToolRetry]:
+        """Validate a result by calling the function.
+
+        Args:
+            result: The result data after Pydantic validation of the message content.
+            deps: The agent dependencies.
+            retry: The current retry number.
+            tool_call: The original tool call message.
+
+        Returns:
+            Either the validated result data (left) or a retry message (right).
+        """
+        if self._takes_ctx:
+            args = CallContext(deps, retry), result
+        else:
+            args = (result,)
+
+        try:
+            if self._is_async:
+                function = cast(Callable[[Any], Awaitable[ResultData]], self.function)
+                result_data = await function(*args)
+            else:
+                function = cast(Callable[[Any], ResultData], self.function)
+                result_data = await _utils.run_in_executor(function, *args)
+        except Retry as r:
+            m = messages.ToolRetry(
+                tool_name=tool_call.tool_name,
+                content=r.message,
+                tool_id=tool_call.tool_id,
+            )
+            return _utils.Either(right=m)
+        else:
+            return _utils.Either(left=result_data)
diff --git a/pydantic_ai/retrievers.py b/pydantic_ai/retrievers.py
index 44d591fb..a73d8aae 100644
--- a/pydantic_ai/retrievers.py
+++ b/pydantic_ai/retrievers.py
@@ -83,7 +83,7 @@ async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Mes
         try:
             args_dict = self.validator.validate_json(message.arguments)
         except ValidationError as e:
-            return self._on_error(e.errors(), message)
+            return self._on_error(e.errors(include_url=False), message)
 
         args, kwargs = self._call_args(deps, args_dict)
         function = self.function.whichever()
diff --git a/pyproject.toml b/pyproject.toml
index 03603c7b..34760542 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,6 +98,11 @@ reportUnnecessaryTypeIgnoreComment = true
 reportMissingTypeStubs = false
 include = ["pydantic_ai", "tests", "demos"]
 venvPath = ".venv"
+# see https://github.com/microsoft/pyright/issues/7771, we don't want to error on decorated functions in tests
+# which are not otherwise used
+executionEnvironments = [
+    { root = "tests", reportUnusedFunction = false },
+]
 
 [tool.pytest.ini_options]
 xfail_strict = true
diff --git a/tests/test_function_model.py b/tests/test_function_model.py
index f2a5f43b..dd489c36 100644
--- a/tests/test_function_model.py
+++ b/tests/test_function_model.py
@@ -1,14 +1,11 @@
 import json
 from dataclasses import asdict
-from datetime import datetime
-from typing import TYPE_CHECKING, Any
 
 import pydantic_core
 import pytest
 from inline_snapshot import snapshot
-from pydantic import BaseModel
 
-from pydantic_ai import Agent, CallContext
+from pydantic_ai import Agent, CallContext, Retry
 from pydantic_ai.messages import (
     LLMMessage,
     LLMResponse,
@@ -16,17 +13,13 @@
     Message,
     SystemPrompt,
     ToolCall,
+    ToolRetry,
     ToolReturn,
     UserPrompt,
 )
 from pydantic_ai.models.function import FunctionModel, ToolDescription
 from pydantic_ai.models.test import TestModel
-
-if TYPE_CHECKING:
-
-    def IsNow(*args: Any, **kwargs: Any) -> datetime: ...
-else:
-    from dirty_equals import IsNow
+from tests.utils import IsNow
 
 
 def return_last(
@@ -237,7 +230,7 @@ def test_deps_none():
     agent = Agent(FunctionModel(call_retriever), deps=None)
 
     @agent.retriever_context
-    async def get_none(ctx: CallContext[None]):  # pyright: ignore[reportUnusedFunction]
+    async def get_none(ctx: CallContext[None]):
         nonlocal called
         called = True
 
@@ -274,37 +267,6 @@ def get_check_foobar(ctx: CallContext[tuple[str, str]]) -> str:
     assert called
 
 
-def test_result_schema_tuple():
-    def return_tuple(_: list[Message], __: bool, retrievers: dict[str, ToolDescription]) -> LLMMessage:
-        assert len(retrievers) == 1
-        retriever_key = next(iter(retrievers.keys()))
-        tuple_json = '{"response": ["foo", "bar"]}'
-        return LLMToolCalls(calls=[ToolCall(tool_name=retriever_key, arguments=tuple_json)])
-
-    agent = Agent(FunctionModel(return_tuple), deps=None, result_type=tuple[str, str])
-
-    result = agent.run_sync('Hello')
-    assert result.response == ('foo', 'bar')
-
-
-def test_result_schema_pydantic_model():
-    class Foo(BaseModel):
-        a: int
-        b: str
-
-    def return_tuple(_: list[Message], __: bool, retrievers: dict[str, ToolDescription]) -> LLMMessage:
-        assert len(retrievers) == 1
-        retriever_key = next(iter(retrievers.keys()))
-        tuple_json = '{"a": 1, "b": "foo"}'
-        return LLMToolCalls(calls=[ToolCall(tool_name=retriever_key, arguments=tuple_json)])
-
-    agent = Agent(FunctionModel(return_tuple), deps=None, result_type=Foo)
-
-    result = agent.run_sync('Hello')
-    assert isinstance(result.response, Foo)
-    assert result.response.model_dump() == {'a': 1, 'b': 'foo'}
-
-
 def test_model_arg():
     agent = Agent(deps=None)
     result = agent.run_sync('Hello', model=FunctionModel(return_last))
@@ -359,7 +321,7 @@ def f(messages: list[Message], allow_text_result: bool, retrievers: dict[str, To
 
 def test_call_all():
     result = agent_all.run_sync('Hello', model=TestModel())
-    assert result.response == snapshot('Final response')
+    assert result.response == snapshot('{"foo": "1", "bar": "2", "baz": "3", "qux": "4", "quz": "a"}')
     assert result.message_history == snapshot(
         [
             SystemPrompt(content='foobar'),
@@ -379,7 +341,7 @@ def test_call_all():
             ToolReturn(tool_name='baz', content='3', timestamp=IsNow()),
             ToolReturn(tool_name='qux', content='4', timestamp=IsNow()),
             ToolReturn(tool_name='quz', content='a', timestamp=IsNow()),
-            LLMResponse(content='Final response', timestamp=IsNow()),
+            LLMResponse(content='{"foo": "1", "bar": "2", "baz": "3", "qux": "4", "quz": "a"}', timestamp=IsNow()),
         ]
     )
 
@@ -420,3 +382,34 @@ def f(_messages: list[Message], _allow_text_result: bool, retrievers: dict[str,
     )
     # description should be the first key
     assert next(iter(json_schema)) == 'description'
+
+
+def test_retriever_retry():
+    agent = Agent(deps=None)
+    call_count = 0
+
+    @agent.retriever_plain
+    async def my_ret(x: int) -> str:
+        nonlocal call_count
+        call_count += 1
+        if call_count == 1:
+            raise Retry('First call failed')
+        else:
+            return str(x + 1)
+
+    result = agent.run_sync('Hello', model=TestModel())
+    assert call_count == 2
+    assert result.response == snapshot('{"my_ret": "2"}')
+    assert result.message_history == snapshot(
+        [
+            UserPrompt(content='Hello', timestamp=IsNow()),
+            LLMToolCalls(
+                calls=[ToolCall(tool_name='my_ret', arguments='{"x": 0}')],
+                timestamp=IsNow(),
+            ),
+            ToolRetry(tool_name='my_ret', content='First call failed', timestamp=IsNow()),
+            LLMToolCalls(calls=[ToolCall(tool_name='my_ret', arguments='{"x": 1}')], timestamp=IsNow()),
+            ToolReturn(tool_name='my_ret', content='2', timestamp=IsNow()),
+            LLMResponse(content='{"my_ret": "2"}', timestamp=IsNow()),
+        ]
+    )
diff --git a/tests/test_result_validation.py b/tests/test_result_validation.py
new file mode 100644
index 00000000..6def6775
--- /dev/null
+++ b/tests/test_result_validation.py
@@ -0,0 +1,117 @@
+from inline_snapshot import snapshot
+from pydantic import BaseModel
+
+from pydantic_ai import Agent, Retry
+from pydantic_ai.messages import LLMMessage, LLMToolCalls, Message, ToolCall, ToolRetry, UserPrompt
+from pydantic_ai.models.function import FunctionModel, ToolDescription
+from tests.utils import IsNow
+
+
+def test_result_tuple():
+    def return_tuple(_: list[Message], __: bool, retrievers: dict[str, ToolDescription]) -> LLMMessage:
+        assert len(retrievers) == 1
+        retriever_key = next(iter(retrievers.keys()))
+        args_json = '{"response": ["foo", "bar"]}'
+        return LLMToolCalls(calls=[ToolCall(tool_name=retriever_key, arguments=args_json)])
+
+    agent = Agent(FunctionModel(return_tuple), deps=None, result_type=tuple[str, str])
+
+    result = agent.run_sync('Hello')
+    assert result.response == ('foo', 'bar')
+
+
+class Foo(BaseModel):
+    a: int
+    b: str
+
+
+def test_result_pydantic_model():
+    def return_model(_: list[Message], __: bool, retrievers: dict[str, ToolDescription]) -> LLMMessage:
+        assert len(retrievers) == 1
+        retriever_key = next(iter(retrievers.keys()))
+        args_json = '{"a": 1, "b": "foo"}'
+        return LLMToolCalls(calls=[ToolCall(tool_name=retriever_key, arguments=args_json)])
+
+    agent = Agent(FunctionModel(return_model), deps=None, result_type=Foo)
+
+    result = agent.run_sync('Hello')
+    assert isinstance(result.response, Foo)
+    assert result.response.model_dump() == {'a': 1, 'b': 'foo'}
+
+
+def test_result_pydantic_model_retry():
+    def return_model(messages: list[Message], __: bool, retrievers: dict[str, ToolDescription]) -> LLMMessage:
+        assert len(retrievers) == 1
+        retriever_key = next(iter(retrievers.keys()))
+        if len(messages) == 1:
+            args_json = '{"a": "wrong", "b": "foo"}'
+        else:
+            args_json = '{"a": 42, "b": "foo"}'
+        return LLMToolCalls(calls=[ToolCall(tool_name=retriever_key, arguments=args_json)])
+
+    agent = Agent(FunctionModel(return_model), deps=None, result_type=Foo)
+
+    result = agent.run_sync('Hello')
+    assert isinstance(result.response, Foo)
+    assert result.response.model_dump() == {'a': 42, 'b': 'foo'}
+    assert result.message_history == snapshot(
+        [
+            UserPrompt(content='Hello', timestamp=IsNow()),
+            LLMToolCalls(
+                calls=[ToolCall(tool_name='final_result', arguments='{"a": "wrong", "b": "foo"}')],
+                timestamp=IsNow(),
+            ),
+            ToolRetry(
+                tool_name='final_result',
+                content=[
+                    {
+                        'type': 'int_parsing',
+                        'loc': ('a',),
+                        'msg': 'Input should be a valid integer, unable to parse string as an integer',
+                        'input': 'wrong',
+                    }
+                ],
+                timestamp=IsNow(),
+            ),
+            LLMToolCalls(
+                calls=[ToolCall(tool_name='final_result', arguments='{"a": 42, "b": "foo"}')],
+                timestamp=IsNow(),
+            ),
+        ]
+    )
+
+
+def test_result_validator():
+    def return_model(messages: list[Message], __: bool, retrievers: dict[str, ToolDescription]) -> LLMMessage:
+        assert len(retrievers) == 1
+        retriever_key = next(iter(retrievers.keys()))
+        if len(messages) == 1:
+            args_json = '{"a": 41, "b": "foo"}'
+        else:
+            args_json = '{"a": 42, "b": "foo"}'
+        return LLMToolCalls(calls=[ToolCall(tool_name=retriever_key, arguments=args_json)])
+
+    agent = Agent(FunctionModel(return_model), deps=None, result_type=Foo)
+
+    @agent.result_validator
+    def validate_result(r: Foo) -> Foo:
+        if r.a == 42:
+            return r
+        else:
+            raise Retry('"a" should be 42')
+
+    result = agent.run_sync('Hello')
+    assert isinstance(result.response, Foo)
+    assert result.response.model_dump() == {'a': 42, 'b': 'foo'}
+    assert result.message_history == snapshot(
+        [
+            UserPrompt(content='Hello', timestamp=IsNow()),
+            LLMToolCalls(
+                calls=[ToolCall(tool_name='final_result', arguments='{"a": 41, "b": "foo"}')], timestamp=IsNow()
+            ),
+            ToolRetry(tool_name='final_result', content='"a" should be 42', timestamp=IsNow()),
+            LLMToolCalls(
+                calls=[ToolCall(tool_name='final_result', arguments='{"a": 42, "b": "foo"}')], timestamp=IsNow()
+            ),
+        ]
+    )
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 00000000..7768fd78
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,11 @@
+from datetime import datetime
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+
+    def IsNow(*args: Any, **kwargs: Any) -> datetime: ...
+else:
+    from dirty_equals import IsNow
+
+
+__all__ = ('IsNow',)