Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename modules, use exception not Either for return #5

Merged
merged 2 commits into from
Oct 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pydantic_ai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .agent import Agent
from .retrievers import CallContext, Retry
from .call import CallContext, Retry

__all__ = 'Agent', 'CallContext', 'Retry'
7 changes: 4 additions & 3 deletions pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like

if TYPE_CHECKING:
from . import retrievers as _r
from . import _retriever
from .call import AgentDeps


__all__ = ('function_schema',)
Expand All @@ -39,7 +40,7 @@ class FunctionSchema(TypedDict):
var_positional_field: str | None


def function_schema(either_function: _r.RetrieverEitherFunc[_r.AgentDeps, _r.P]) -> FunctionSchema:
def function_schema(either_function: _retriever.RetrieverEitherFunc[AgentDeps, _retriever.P]) -> FunctionSchema:
"""Build a Pydantic validator and JSON schema from a retriever function.

Args:
Expand Down Expand Up @@ -227,7 +228,7 @@ def _infer_docstring_style(doc: str) -> DocstringStyle:


def _is_call_ctx(annotation: Any) -> bool:
from .retrievers import CallContext
from .call import CallContext

return annotation is CallContext or (
_typing_extra.is_generic_alias(annotation) and get_origin(annotation) is CallContext
Expand Down
34 changes: 7 additions & 27 deletions pydantic_ai/retrievers.py → pydantic_ai/_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,20 @@
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import Any, Callable, Generic, TypeVar, Union, cast
from typing import Any, Callable, Generic, Union, cast

import pydantic_core
from pydantic import ValidationError
from pydantic_core import SchemaValidator
from typing_extensions import Concatenate, ParamSpec

from . import _pydantic, _utils, messages
from .call import AgentDeps, CallContext, Retry

AgentDeps = TypeVar('AgentDeps')
# retrieval function parameters
P = ParamSpec('P')


@dataclass
class CallContext(Generic[AgentDeps]):
"""Information about the current call."""

deps: AgentDeps
# do we allow retries within functions?
retry: int


# Usage `RetrieverContextFunc[AgentDependencies, P]`
RetrieverContextFunc = Callable[Concatenate[CallContext[AgentDeps], P], Union[str, Awaitable[str]]]
# Usage `RetrieverPlainFunc[P]`
Expand All @@ -34,14 +25,6 @@ class CallContext(Generic[AgentDeps]):
RetrieverEitherFunc = _utils.Either[RetrieverContextFunc[AgentDeps, P], RetrieverPlainFunc[P]]


class Retry(Exception):
"""Exception raised when a retriever function should be retried."""

def __init__(self, message: str):
self.message = message
super().__init__(message)


@dataclass(init=False)
class Retriever(Generic[AgentDeps, P]):
"""A retriever function for an agent."""
Expand Down Expand Up @@ -86,23 +69,20 @@ async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Mes
return self._on_error(e.errors(include_url=False), message)

args, kwargs = self._call_args(deps, args_dict)
function = self.function.whichever()
try:
if self.is_async:
response_content = await function(*args, **kwargs) # pyright: ignore[reportCallIssue,reportUnknownVariableType,reportGeneralTypeIssues]
function = cast(Callable[[Any], Awaitable[str]], self.function.whichever())
response_content = await function(*args, **kwargs)
else:
response_content = await _utils.run_in_executor(
function, # pyright: ignore[reportArgumentType,reportCallIssue]
*args,
**kwargs,
)
function = cast(Callable[[Any], str], self.function.whichever())
response_content = await _utils.run_in_executor(function, *args, **kwargs)
except Retry as e:
return self._on_error(e.message, message)

self._current_retry = 0
return messages.ToolReturn(
tool_name=message.tool_name,
content=cast(str, response_content),
content=response_content,
tool_id=message.tool_id,
)

Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai/_system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Callable, Generic, Union, cast

from . import _utils
from .retrievers import AgentDeps, CallContext
from .call import AgentDeps, CallContext

# A function that may or maybe not take `CallInfo` as an argument, and may or may not be async.
# Usage `SystemPromptFunc[AgentDeps]`
Expand Down
27 changes: 13 additions & 14 deletions pydantic_ai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,52 +86,51 @@ class Some(Generic[_T]):
Option: TypeAlias = Union[Some[_T], None]


_Left = TypeVar('_Left')
_Right = TypeVar('_Right')
Left = TypeVar('Left')
Right = TypeVar('Right')


class Either(Generic[_Left, _Right]):
class Either(Generic[Left, Right]):
"""Two member Union that records which member was set, this is analogous to Rust enums with two variants.

Usage:

```py
if left_thing := either.left:
use_left(left_thing)
use_left(left_thing.value)
else:
use_right(either.right)
```
"""

__slots__ = '_left', '_right'

@overload
def __init__(self, *, left: _Left) -> None: ...
def __init__(self, *, left: Left) -> None: ...

@overload
def __init__(self, *, right: _Right) -> None: ...
def __init__(self, *, right: Right) -> None: ...

def __init__(self, **kwargs: Any) -> None:
keys = set(kwargs.keys())
if keys == {'left'}:
self._left: Option[_Left] = Some(kwargs['left'])
self._right: _Right | None = None
self._left: Option[Left] = Some(kwargs['left'])
elif keys == {'right'}:
self._left = None
self._right = kwargs['right']
else:
raise TypeError('Either must receive exactly one value - `left` or `right`')
raise TypeError('Either must receive exactly one argument - `left` or `right`')

@property
def left(self) -> Option[_Left]:
def left(self) -> Option[Left]:
return self._left

@property
def right(self) -> _Right:
if self._right is None:
raise TypeError('Right not set')
def right(self) -> Right:
return self._right

def is_left(self) -> bool:
return self._left is not None

def whichever(self) -> _Left | _Right:
def whichever(self) -> Left | Right:
return self._left.value if self._left is not None else self.right
38 changes: 15 additions & 23 deletions pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@

from typing_extensions import assert_never

from . import _system_prompt, _utils, messages as _messages, models as _models, result as _result, retrievers as _r
from .models import Model
from . import _retriever as _r, _system_prompt, _utils, messages as _messages, models as _models, result as _result
from .call import AgentDeps
from .result import ResultData
from .retrievers import AgentDeps

__all__ = ('Agent',)
KnownModelName = Literal['openai:gpt-4o', 'openai:gpt-4-turbo', 'openai:gpt-4', 'openai:gpt-3.5-turbo']
Expand All @@ -21,7 +20,7 @@ class Agent(Generic[AgentDeps, ResultData]):
"""Main class for creating "agents" - a way to have a specific type of "conversation" with an LLM."""

# slots mostly for my sanity — knowing what attributes are available
_model: Model | None
_model: _models.Model | None
_result_tool: _result.ResultSchema[ResultData] | None
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
_allow_text_result: bool
Expand Down Expand Up @@ -198,7 +197,7 @@ def retriever_decorator(func_: _r.RetrieverPlainFunc[_r.P]) -> _r.Retriever[Agen
return self._register_retriever(_utils.Either(right=func), retries)

def _register_retriever(
self, func: _r.RetrieverEitherFunc[_r.AgentDeps, _r.P], retries: int | None
self, func: _r.RetrieverEitherFunc[AgentDeps, _r.P], retries: int | None
) -> _r.Retriever[AgentDeps, _r.P]:
"""Private utility to register a retriever function."""
retries_ = retries if retries is not None else self._default_retries
Expand Down Expand Up @@ -234,18 +233,17 @@ async def _handle_model_response(
# NOTE: this means we ignore any other tools called here
call = next((c for c in llm_message.calls if c.tool_name == self._result_tool.name), None)
if call is not None:
either = self._result_tool.validate(call)
if result_data := either.left:
either = await self._validate_result(result_data.value, deps, call)

if result_data := either.left:
return _utils.Some(result_data.value)
else:
try:
result = self._result_tool.validate(call)
result = await self._validate_result(result, deps, call)
except _result.ToolRetryError as e:
self._incr_result_retry()
messages.append(either.right)
messages.append(e.tool_retry)
return None
else:
return _utils.Some(result)

# otherwise we run all functions in parallel
# otherwise we run all retriever functions in parallel
coros: list[Awaitable[_messages.Message]] = []
for call in llm_message.calls:
retriever = self._retrievers.get(call.tool_name)
Expand All @@ -257,16 +255,10 @@ async def _handle_model_response(
else:
assert_never(llm_message)

async def _validate_result(
self, result: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall
) -> _utils.Either[ResultData, _messages.ToolRetry]:
async def _validate_result(self, result: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall) -> ResultData:
for validator in self._result_validators:
either = await validator.validate(result, deps, self._current_result_retry, tool_call)
if either.left:
result = either.left.value
else:
return either
return _utils.Either(left=result)
result = await validator.validate(result, deps, self._current_result_retry, tool_call)
return result

def _incr_result_retry(self) -> None:
self._current_result_retry += 1
Expand Down
22 changes: 22 additions & 0 deletions pydantic_ai/call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations as _annotations

from dataclasses import dataclass
from typing import Generic, TypeVar

AgentDeps = TypeVar('AgentDeps')


@dataclass
class CallContext(Generic[AgentDeps]):
"""Information about the current call."""

deps: AgentDeps
retry: int


class Retry(Exception):
"""Exception raised when a retriever function should be retried."""

def __init__(self, message: str):
self.message = message
super().__init__(message)
22 changes: 14 additions & 8 deletions pydantic_ai/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing_extensions import Self, TypedDict

from . import _utils, messages
from .retrievers import AgentDeps, CallContext, Retry
from .call import AgentDeps, CallContext, Retry

ResultData = TypeVar('ResultData')

Expand Down Expand Up @@ -72,7 +72,7 @@ def build(cls, response_type: type[ResultData], name: str, description: str) ->
outer_typed_dict=outer_typed_dict,
)

def validate(self, tool_call: messages.ToolCall) -> _utils.Either[ResultData, messages.ToolRetry]:
def validate(self, tool_call: messages.ToolCall) -> ResultData:
"""Validate a result message.

Returns:
Expand All @@ -86,11 +86,11 @@ def validate(self, tool_call: messages.ToolCall) -> _utils.Either[ResultData, me
content=e.errors(include_url=False),
tool_id=tool_call.tool_id,
)
return _utils.Either(right=m)
raise ToolRetryError(m) from e
else:
if self.outer_typed_dict:
result = result['response']
return _utils.Either(left=result)
return result


# A function that always takes `ResultData` and returns `ResultData`,
Expand All @@ -116,7 +116,7 @@ def __post_init__(self):

async def validate(
self, result: ResultData, deps: AgentDeps, retry: int, tool_call: messages.ToolCall
) -> _utils.Either[ResultData, messages.ToolRetry]:
) -> ResultData:
"""Validate a result but calling the function.

Args:
Expand All @@ -126,7 +126,7 @@ async def validate(
tool_call: The original tool call message.

Returns:
Either the validated result data (left) or a retry message (right).
Result of either the validated result data (ok) or a retry message (Err).
"""
if self._takes_ctx:
args = CallContext(deps, retry), result
Expand All @@ -146,6 +146,12 @@ async def validate(
content=r.message,
tool_id=tool_call.tool_id,
)
return _utils.Either(right=m)
raise ToolRetryError(m) from r
else:
return _utils.Either(left=result_data)
return result_data


class ToolRetryError(Exception):
def __init__(self, tool_retry: messages.ToolRetry):
self.tool_retry = tool_retry
super().__init__()