
Commit

validate_result and test Retry (#4)
samuelcolvin authored Oct 13, 2024
1 parent c018054 commit 9e191df
Showing 11 changed files with 371 additions and 129 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ __pycache__
/scratch/
/.coverage
env*/
/TODO.md
37 changes: 18 additions & 19 deletions pydantic_ai/_system_prompt.py
@@ -3,16 +3,16 @@
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import Callable, Generic, Union
from typing import Any, Callable, Generic, Union, cast

from . import _utils, retrievers as _r
from .retrievers import AgentDeps
from . import _utils
from .retrievers import AgentDeps, CallContext

# This is basically a function that may or may not take `CallContext` as an argument, and may or may not be async.
# Usage `SystemPromptFunc[AgentDependencies]`
# A function that may or may not take `CallContext` as an argument, and may or may not be async.
# Usage `SystemPromptFunc[AgentDeps]`
SystemPromptFunc = Union[
Callable[[_r.CallContext[AgentDeps]], str],
Callable[[_r.CallContext[AgentDeps]], Awaitable[str]],
Callable[[CallContext[AgentDeps]], str],
Callable[[CallContext[AgentDeps]], Awaitable[str]],
Callable[[], str],
Callable[[], Awaitable[str]],
]
@@ -21,23 +21,22 @@
@dataclass
class SystemPromptRunner(Generic[AgentDeps]):
function: SystemPromptFunc[AgentDeps]
takes_ctx: bool = False
is_async: bool = False
_takes_ctx: bool = False
_is_async: bool = False

def __post_init__(self):
self.takes_ctx = len(inspect.signature(self.function).parameters) > 0
self.is_async = inspect.iscoroutinefunction(self.function)
self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
self._is_async = inspect.iscoroutinefunction(self.function)

async def run(self, deps: AgentDeps) -> str:
if self.takes_ctx:
args = (_r.CallContext(deps, 0),)
if self._takes_ctx:
args = (CallContext(deps, 0),)
else:
args = ()

if self.is_async:
return await self.function(*args) # pyright: ignore[reportGeneralTypeIssues,reportUnknownVariableType]
if self._is_async:
function = cast(Callable[[Any], Awaitable[str]], self.function)
return await function(*args)
else:
return await _utils.run_in_executor(
self.function, # pyright: ignore[reportArgumentType,reportReturnType]
*args,
)
function = cast(Callable[[Any], str], self.function)
return await _utils.run_in_executor(function, *args)
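
To ground the four shapes in the `SystemPromptFunc` union above, here is a minimal usage sketch (not part of this commit): the bare `@agent.system_prompt` decorator comes from agent.py further down, the `Agent` constructor arguments beyond the model name are omitted, and the `ctx.deps` attribute name is an assumption inferred from the `CallContext(deps, 0)` call above.

from pydantic_ai.agent import Agent
from pydantic_ai.retrievers import CallContext

agent = Agent('openai:gpt-4o')  # other constructor arguments omitted for brevity

@agent.system_prompt
def fixed_prompt() -> str:
    # sync, no CallContext: _takes_ctx=False, _is_async=False, so run() goes through run_in_executor
    return 'Be concise.'

@agent.system_prompt
async def prompt_from_deps(ctx: CallContext[str]) -> str:
    # async with CallContext: _takes_ctx=True, _is_async=True, so run() awaits it directly
    return f'The user is {ctx.deps}.'  # `ctx.deps` is an assumed attribute name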
21 changes: 13 additions & 8 deletions pydantic_ai/_utils.py
@@ -91,7 +91,7 @@ class Some(Generic[_T]):


class Either(Generic[_Left, _Right]):
"""Two member Union that records which member was set.
"""Two member Union that records which member was set, this is analogous to Rust enums with two variants.
Usage:
@@ -109,14 +109,19 @@ def __init__(self, *, left: _Left) -> None: ...
@overload
def __init__(self, *, right: _Right) -> None: ...

def __init__(self, *, left: _Left | None = None, right: _Right | None = None) -> None:
if (left is not None and right is not None) or (left is None and right is None):
raise TypeError('Either must have exactly one value')
self._left = left
self._right = right
def __init__(self, **kwargs: Any) -> None:
keys = set(kwargs.keys())
if keys == {'left'}:
self._left: Option[_Left] = Some(kwargs['left'])
self._right: _Right | None = None
elif keys == {'right'}:
self._left = None
self._right = kwargs['right']
else:
raise TypeError('Either must receive exactly one value - `left` or `right`')

@property
def left(self) -> _Left | None:
def left(self) -> Option[_Left]:
return self._left

@property
@@ -129,4 +134,4 @@ def is_left(self) -> bool:
return self._left is not None

def whichever(self) -> _Left | _Right:
return self._left or self.right
return self._left.value if self._left is not None else self.right
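
A short illustrative sketch of what the `Some` wrapping above buys, using only the API shown in this hunk: with the old `self._left or self.right` in `whichever`, a falsy left value (such as an empty string) incorrectly fell through to `right`; wrapping left in `Some` fixes that.

from pydantic_ai._utils import Either

e = Either(left='')            # exactly one of `left` / `right` must be passed
assert e.is_left()             # true even for a falsy left value
assert e.left is not None      # .left is now Option[_Left], i.e. Some(''), not the raw value
assert e.whichever() == ''     # whichever() unwraps the Some
# Either(left=1, right=2) or Either() would raise TypeError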
63 changes: 47 additions & 16 deletions pydantic_ai/agent.py
@@ -2,32 +2,36 @@

import asyncio
from collections.abc import Awaitable, Sequence
from dataclasses import dataclass
from typing import Any, Callable, Generic, Literal, cast, overload

from typing_extensions import assert_never

from . import _system_prompt, _utils, messages as _messages, models as _models, result as _result, retrievers as _r
from .models import Model
from .result import ResultData
from .retrievers import AgentDeps

__all__ = ('Agent',)
KnownModelName = Literal['openai:gpt-4o', 'openai:gpt-4-turbo', 'openai:gpt-4', 'openai:gpt-3.5-turbo']


@dataclass(init=False)
class Agent(Generic[AgentDeps, ResultData]):
"""Main class for creating "agents" - a way to have a specific type of "conversation" with an LLM."""

# slots mostly for my sanity — knowing what attributes are available
__slots__ = (
'_model',
'_result_tool',
'_allow_text_result',
'_system_prompts',
'_retrievers',
'_default_retries',
'_system_prompt_functions',
'_default_deps',
)
_model: Model | None
_result_tool: _result.ResultSchema[ResultData] | None
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
_allow_text_result: bool
_system_prompts: tuple[str, ...]
_retrievers: dict[str, _r.Retriever[AgentDeps, Any]]
_default_retries: int
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]]
_default_deps: AgentDeps
_max_result_retries: int
_current_result_retry: int

def __init__(
self,
@@ -45,10 +49,7 @@ def __init__(
self._model = _models.infer_model(model) if model is not None else None

self._result_tool = _result.ResultSchema[result_type].build(
result_type,
result_tool_name,
result_tool_description,
result_retries if result_retries is not None else retries,
result_type, result_tool_name, result_tool_description
)
# if the result tool is None, or its schema allows `str`, we allow plain text results
self._allow_text_result = self._result_tool is None or self._result_tool.allow_text_result
@@ -57,7 +58,10 @@ def __init__(
self._retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = {}
self._default_deps = cast(AgentDeps, None if deps == () else deps)
self._default_retries = retries
self._system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = []
self._system_prompt_functions = []
self._max_result_retries = result_retries if result_retries is not None else retries
self._current_result_retry = 0
self._result_validators = []

async def run(
self,
@@ -142,6 +146,13 @@ def system_prompt(
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
return func

def result_validator(
self, func: _result.ResultValidatorFunc[AgentDeps, ResultData]
) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
"""Decorator to register a result validator function."""
self._result_validators.append(_result.ResultValidator(func))
return func

@overload
def retriever_context(self, func: _r.RetrieverContextFunc[AgentDeps, _r.P], /) -> _r.Retriever[AgentDeps, _r.P]: ...

@@ -225,8 +236,12 @@ async def _handle_model_response(
if call is not None:
either = self._result_tool.validate(call)
if result_data := either.left:
return _utils.Some(result_data)
either = await self._validate_result(result_data.value, deps, call)

if result_data := either.left:
return _utils.Some(result_data.value)
else:
self._incr_result_retry()
messages.append(either.right)
return None

@@ -242,6 +257,22 @@ async def _handle_model_response(
else:
assert_never(llm_message)

async def _validate_result(
self, result: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall
) -> _utils.Either[ResultData, _messages.ToolRetry]:
for validator in self._result_validators:
either = await validator.validate(result, deps, self._current_result_retry, tool_call)
if either.left:
result = either.left.value
else:
return either
return _utils.Either(left=result)

def _incr_result_retry(self) -> None:
self._current_result_retry += 1
if self._current_result_retry > self._max_result_retries:
raise RuntimeError(f'Exceeded maximum retries ({self._max_result_retries}) for result validation')

async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]:
"""Build the initial messages for the conversation."""
messages: list[_messages.Message] = [_messages.SystemPrompt(p) for p in self._system_prompts]
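
To restate the new result-validation flow as a self-contained sketch: each validator either passes on a (possibly transformed) result as `left` or short-circuits with `right` to trigger a retry, which is what `_validate_result` above does before `_incr_result_retry` enforces the cap. The plain callables below are illustrative stand-ins; the real `ResultValidatorFunc` signature lives in pydantic_ai/result.py, which this diff does not show.

from pydantic_ai import _utils

def strip_whitespace(result: str) -> _utils.Either:
    # a validator may transform the result and pass it on as `left`
    return _utils.Either(left=result.strip())

def must_not_be_empty(result: str) -> _utils.Either:
    # or reject it, returning `right` to ask the model to retry
    if result:
        return _utils.Either(left=result)
    return _utils.Either(right='response must not be empty')

def chain(validators, result):
    # mirrors Agent._validate_result: validators run in order, each may transform
    # the result; the first `right` short-circuits into a retry
    for validate in validators:
        either = validate(result)
        if either.left:
            result = either.left.value
        else:
            return either
    return _utils.Either(left=result)

assert chain([strip_whitespace, must_not_be_empty], '  hi  ').left.value == 'hi'
assert chain([strip_whitespace, must_not_be_empty], '   ').right == 'response must not be empty'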
63 changes: 43 additions & 20 deletions pydantic_ai/models/test.py
@@ -12,7 +12,7 @@
from typing import Any, Literal

from .. import _utils
from ..messages import LLMMessage, LLMResponse, LLMToolCalls, Message, ToolCall
from ..messages import LLMMessage, LLMResponse, LLMToolCalls, Message, ToolCall, ToolRetry, ToolReturn
from . import AbstractToolDefinition, AgentModel, Model


@@ -38,23 +38,23 @@ def agent_model(
self, allow_text_result: bool, tools: list[AbstractToolDefinition], result_tool_name: str | None
) -> AgentModel:
if self.call_retrievers == 'all':
retriever_calls = [(r.name, gen_retriever_args(r)) for r in tools if r.name != 'response']
retriever_calls = [(r.name, r) for r in tools if r.name != 'response']
else:
lookup = {r.name: r for r in tools}
retriever_calls = [(name, gen_retriever_args(lookup[name])) for name in self.call_retrievers]
retriever_calls = [(name, lookup[name]) for name in self.call_retrievers]

if self.custom_result_text is not None:
if not allow_text_result:
raise ValueError('Plain response not allowed, but `custom_result_text` is set.')
result: _utils.Either[str, str] = _utils.Either(left=self.custom_result_text)
result: _utils.Either[str | None, AbstractToolDefinition] = _utils.Either(left=self.custom_result_text)
elif self.custom_result_args is not None:
assert result_tool_name is not None, 'No result tool name provided, but `custom_result_args` is set.'
result = _utils.Either(right=self.custom_result_args)
elif result_tool_name is not None:
response_def = next(r for r in tools if r.name == result_tool_name)
result = _utils.Either(right=gen_retriever_args(response_def))
result = _utils.Either(right=response_def)
else:
result = _utils.Either(left='Final response')
result = _utils.Either(left=None)
return TestAgentModel(retriever_calls, result, result_tool_name)


@@ -63,31 +63,54 @@ class TestAgentModel(AgentModel):
# NOTE: Avoid test discovery by pytest.
__test__ = False

retriever_calls: list[tuple[str, str]]
retriever_calls: list[tuple[str, AbstractToolDefinition]]
# left means the text is plain text; right means it's a function call
result: _utils.Either[str, str]
result: _utils.Either[str | None, AbstractToolDefinition]
result_tool_name: str | None
step: int = 0
last_message_count: int = 0

async def request(self, messages: list[Message]) -> LLMMessage:
if self.step == 0:
calls = [
ToolCall(tool_name=name, arguments=self.gen_retriever_args(args)) for name, args in self.retriever_calls
]
self.step += 1
return LLMToolCalls(calls=[ToolCall(tool_name=name, arguments=args) for name, args in self.retriever_calls])
elif self.step == 1:
self.last_message_count = len(messages)
return LLMToolCalls(calls=calls)

new_messages = messages[self.last_message_count :]
self.last_message_count = len(messages)
new_retry_names = {m.tool_name for m in new_messages if isinstance(m, ToolRetry)}
if new_retry_names:
calls = [
ToolCall(tool_name=name, arguments=self.gen_retriever_args(args))
for name, args in self.retriever_calls
if name in new_retry_names
]
self.step += 1
return LLMToolCalls(calls=calls)
else:
if response_text := self.result.left:
return LLMResponse(content=response_text)
self.step += 1
if response_text.value is None:
# build up details of retriever responses
output: dict[str, str] = {}
for message in messages:
if isinstance(message, ToolReturn):
output[message.tool_name] = message.content
return LLMResponse(content=json.dumps(output))
else:
return LLMResponse(content=response_text.value)
else:
assert self.result_tool_name is not None, 'No result tool name provided'
response_args = self.result.right
response_args = self.gen_retriever_args(self.result.right)
self.step += 1
return LLMToolCalls(calls=[ToolCall(tool_name=self.result_tool_name, arguments=response_args)])
else:
raise ValueError('Invalid step')


def gen_retriever_args(tool_def: AbstractToolDefinition) -> str:
"""Generate arguments for a retriever."""
return _JsonSchemaTestData(tool_def.json_schema).generate_json()
def gen_retriever_args(self, tool_def: AbstractToolDefinition) -> str:
"""Generate arguments for a retriever."""
return _JsonSchemaTestData(tool_def.json_schema, self.step).generate_json()


_chars = string.ascii_letters + string.digits + string.punctuation
@@ -100,10 +123,10 @@ class _JsonSchemaTestData:
This tries to generate the minimal viable data for the schema.
"""

def __init__(self, schema: _utils.ObjectJsonSchema):
def __init__(self, schema: _utils.ObjectJsonSchema, seed: int = 0):
self.schema = schema
self.defs = schema.get('$defs', {})
self.seed = 0
self.seed = seed

def generate(self) -> Any:
"""Generate data for the JSON schema."""
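
Passing `self.step` into `_JsonSchemaTestData` as the seed means a retriever that is re-called after a `ToolRetry` gets freshly generated arguments instead of repeating the payload that just failed. A rough, isolated sketch of that helper (the schema here is made up and `_JsonSchemaTestData` is an internal class):

from pydantic_ai.models.test import _JsonSchemaTestData

schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}
first_args = _JsonSchemaTestData(schema, seed=0).generate_json()  # initial call, step 0
retry_args = _JsonSchemaTestData(schema, seed=1).generate_json()  # re-issued call after a ToolRetry, step 1
# the seed feeds the generator, so the retried call can produce different argument values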
