Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add result_type to run methods #565

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions pydantic_ai_slim/pydantic_ai/_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@

from . import _utils, messages as _messages
from .exceptions import ModelRetry
from .result import ResultData, ResultValidatorFunc
from .result import NewResultData, ResultData, ResultValidatorFunc
from .tools import AgentDeps, RunContext, ToolDefinition


@dataclass
class ResultValidator(Generic[AgentDeps, ResultData]):
function: ResultValidatorFunc[AgentDeps, ResultData]
class ResultValidator(Generic[AgentDeps, ResultData, NewResultData]):
function: ResultValidatorFunc[AgentDeps, ResultData | NewResultData]
_takes_ctx: bool = field(init=False)
_is_async: bool = field(init=False)

Expand All @@ -28,10 +28,10 @@ def __post_init__(self):

async def validate(
self,
result: ResultData,
result: ResultData | NewResultData,
tool_call: _messages.ToolCallPart | None,
run_context: RunContext[AgentDeps],
) -> ResultData:
) -> ResultData | NewResultData:
"""Validate a result but calling the function.

Args:
Expand All @@ -50,10 +50,10 @@ async def validate(

try:
if self._is_async:
function = cast(Callable[[Any], Awaitable[ResultData]], self.function)
function = cast(Callable[[Any], Awaitable[ResultData | NewResultData]], self.function)
result_data = await function(*args)
else:
function = cast(Callable[[Any], ResultData], self.function)
function = cast(Callable[[Any], ResultData | NewResultData], self.function)
result_data = await _utils.run_in_executor(function, *args)
except ModelRetry as r:
m = _messages.RetryPromptPart(content=r.message)
Expand All @@ -74,17 +74,17 @@ def __init__(self, tool_retry: _messages.RetryPromptPart):


@dataclass
class ResultSchema(Generic[ResultData]):
class ResultSchema(Generic[ResultData, NewResultData]):
"""Model the final response from an agent run.

Similar to `Tool` but for the final result of running an agent.
"""

tools: dict[str, ResultTool[ResultData]]
tools: dict[str, ResultTool[ResultData | NewResultData]]
allow_text_result: bool

@classmethod
def build(cls, response_type: type[ResultData], name: str, description: str | None) -> Self | None:
def build(cls, response_type: type[ResultData | NewResultData], name: str, description: str | None) -> Self | None:
"""Build a ResultSchema dataclass from a response type."""
if response_type is str:
return None
Expand All @@ -95,10 +95,10 @@ def build(cls, response_type: type[ResultData], name: str, description: str | No
else:
allow_text_result = False

def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultData]:
return cast(ResultTool[ResultData], ResultTool(a, tool_name_, description, multiple))
def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultData | NewResultData]:
return cast(ResultTool[ResultData | NewResultData], ResultTool(a, tool_name_, description, multiple))

tools: dict[str, ResultTool[ResultData]] = {}
tools: dict[str, ResultTool[ResultData | NewResultData]] = {}
if args := get_union_args(response_type):
for i, arg in enumerate(args, start=1):
tool_name = union_tool_name(name, arg)
Expand All @@ -112,7 +112,7 @@ def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultDat

def find_named_tool(
self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData | NewResultData]] | None:
"""Find a tool that matches one of the calls, with a specific name."""
for part in parts:
if isinstance(part, _messages.ToolCallPart):
Expand All @@ -122,7 +122,7 @@ def find_named_tool(
def find_tool(
self,
parts: Iterable[_messages.ModelResponsePart],
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData | NewResultData]] | None:
"""Find a tool that matches one of the calls."""
for part in parts:
if isinstance(part, _messages.ToolCallPart):
Expand All @@ -142,11 +142,13 @@ def tool_defs(self) -> list[ToolDefinition]:


@dataclass(init=False)
class ResultTool(Generic[ResultData]):
class ResultTool(Generic[ResultData, NewResultData]):
tool_def: ToolDefinition
type_adapter: TypeAdapter[Any]

def __init__(self, response_type: type[ResultData], name: str, description: str | None, multiple: bool):
def __init__(
self, response_type: type[ResultData | NewResultData], name: str, description: str | None, multiple: bool
):
"""Build a ResultTool dataclass from a response type."""
assert response_type is not str, 'ResultTool does not support str as a response type'

Expand Down Expand Up @@ -183,7 +185,7 @@ def __init__(self, response_type: type[ResultData], name: str, description: str

def validate(
self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
) -> ResultData:
) -> ResultData | NewResultData:
"""Validate a result message.

Args:
Expand Down
Loading
Loading