Skip to content

Commit

Permalink
allow multiple result tools for unions (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Oct 28, 2024
1 parent 7febaae commit 0042877
Show file tree
Hide file tree
Showing 17 changed files with 293 additions and 194 deletions.
15 changes: 15 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ test:
uv run coverage run -m pytest
@uv run coverage report

# Runs the pytest suite under coverage once for each Python version from 3.9 to 3.13.
# Every run sets UV_PROJECT_ENVIRONMENT to a version-specific directory (.venv39 ... .venv313).
# After each run the `.coverage` data file is renamed to `.coverage.<version>`,
# presumably so the next run does not replace it (confirm against coverage's defaults).
# The last two steps combine the per-version data files and print one report.
.PHONY: test-all-python # Run tests on Python 3.9 to 3.13
test-all-python:
UV_PROJECT_ENVIRONMENT=.venv39 uv run --python 3.9 coverage run -m pytest
@mv .coverage .coverage.3.9
UV_PROJECT_ENVIRONMENT=.venv310 uv run --python 3.10 coverage run -m pytest
@mv .coverage .coverage.3.10
UV_PROJECT_ENVIRONMENT=.venv311 uv run --python 3.11 coverage run -m pytest
@mv .coverage .coverage.3.11
UV_PROJECT_ENVIRONMENT=.venv312 uv run --python 3.12 coverage run -m pytest
@mv .coverage .coverage.3.12
UV_PROJECT_ENVIRONMENT=.venv313 uv run --python 3.13 coverage run -m pytest
@mv .coverage .coverage.3.13
@uv run coverage combine
@uv run coverage report

.PHONY: testcov # Run tests and generate a coverage report
testcov: test
@echo "building coverage html"
Expand Down
7 changes: 2 additions & 5 deletions examples/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@
from __future__ import annotations as _annotations

import asyncio
import os
import re
import sys
import unicodedata
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import cast

import asyncpg
import httpx
Expand All @@ -37,7 +35,7 @@
from typing_extensions import AsyncGenerator

from pydantic_ai import CallContext
from pydantic_ai.agent import Agent, KnownModelName
from pydantic_ai.agent import Agent

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire set up
logfire.configure()
Expand All @@ -50,8 +48,7 @@ class Deps:
pool: asyncpg.Pool


model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'openai:gpt-4o'))
agent: Agent[Deps, str] = Agent(model)
agent: Agent[Deps, str] = Agent('openai:gpt-4o')


@agent.retriever_context
Expand Down
43 changes: 29 additions & 14 deletions examples/sql_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,25 @@
Run with:
uv run --extra examples -m examples.sql_gen
uv run --extra examples -m examples.sql_gen "show me logs from yesterday, with level 'error'"
"""

import asyncio
import os
import sys
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import date
from typing import Annotated, Any, cast
from typing import Annotated, Any, Union

import asyncpg
import logfire
from annotated_types import MinLen
from devtools import debug
from pydantic import BaseModel, Field
from typing_extensions import TypeAlias

from pydantic_ai import Agent, CallContext, ModelRetry
from pydantic_ai.agent import KnownModelName

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire set up
logfire.configure()
Expand Down Expand Up @@ -48,17 +48,29 @@


@dataclass
class Response:
class Deps:
conn: asyncpg.Connection


class Success(BaseModel):
"""Response when SQL could be successfully generated."""

sql_query: Annotated[str, MinLen(1)]
explanation: str = Field(None, description='Explanation of the SQL query, as markdown')


@dataclass
class Deps:
conn: asyncpg.Connection
class InvalidRequest(BaseModel):
"""Response when the user input didn't include enough information to generate SQL."""

error_message: str


model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'gemini-1.5-flash'))
agent: Agent[Deps, Response] = Agent(model, result_type=Response)
Response: TypeAlias = Union[Success, InvalidRequest]
agent: Agent[Deps, Response] = Agent(
'gemini-1.5-flash',
# Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
result_type=Response, # type: ignore
)


@agent.system_prompt
Expand Down Expand Up @@ -87,10 +99,13 @@ async def system_prompt() -> str:

@agent.result_validator
async def validate_result(ctx: CallContext[Deps], result: Response) -> Response:
if isinstance(result, InvalidRequest):
return result

# gemini often adds extraneous backslashes to SQL
result.sql_query = result.sql_query.replace('\\', '')
lower_query = result.sql_query.lower()
if not lower_query.startswith('select'):
raise ModelRetry('Please a SELECT query')
if not result.sql_query.upper().startswith('SELECT'):
raise ModelRetry('Please create a SELECT query')

try:
await ctx.deps.conn.execute(f'EXPLAIN {result.sql_query}')
Expand All @@ -109,7 +124,7 @@ async def main():
async with database_connect('postgresql://postgres@localhost', 'pydantic_ai_sql_gen') as conn:
deps = Deps(conn)
result = await agent.run(prompt, deps=deps)
debug(result.response.sql_query)
debug(result.response)


# pyright: reportUnknownMemberType=false
Expand Down
8 changes: 4 additions & 4 deletions examples/weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Any, cast
from typing import Any

import logfire
from devtools import debug
from httpx import AsyncClient

from pydantic_ai import Agent, CallContext, ModelRetry
from pydantic_ai.agent import KnownModelName

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire set up
logfire.configure(send_to_logfire='if-token-present')
Expand All @@ -32,8 +31,9 @@ class Deps:
geo_api_key: str | None


model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'openai:gpt-4o'))
weather_agent: Agent[Deps, str] = Agent(model, system_prompt='Be concise, reply with one sentence.', retries=2)
weather_agent: Agent[Deps, str] = Agent(
'openai:gpt-4o', system_prompt='Be concise, reply with one sentence.', retries=2
)


@weather_agent.retriever_context
Expand Down
8 changes: 7 additions & 1 deletion pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pydantic import ConfigDict, TypeAdapter
from pydantic._internal import _decorators, _generate_schema, _typing_extra
from pydantic._internal._config import ConfigWrapper
from pydantic._internal._typing_extra import origin_is_union
from pydantic.fields import FieldInfo
from pydantic.json_schema import GenerateJsonSchema
from pydantic.plugin._schema_validator import create_schema_validator
Expand All @@ -26,7 +27,12 @@
from .shared import AgentDeps


__all__ = 'function_schema', 'LazyTypeAdapter'
__all__ = 'function_schema', 'LazyTypeAdapter', 'is_union'


def is_union(tp: Any) -> bool:
    """Report whether `tp` is a union type.

    The decision is delegated to pydantic's `origin_is_union`, applied to the
    typing origin of `tp` as returned by `get_origin`.
    """
    return origin_is_union(get_origin(tp))


class FunctionSchema(TypedDict):
Expand Down
109 changes: 94 additions & 15 deletions pydantic_ai/_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import Any, Callable, Generic, Union, cast
from typing import Any, Callable, Generic, Union, cast, get_args

from pydantic import TypeAdapter, ValidationError
from typing_extensions import Self, TypedDict

from . import _utils, messages
from . import _pydantic, _utils, messages
from .messages import LLMToolCalls, ToolCall
from .shared import AgentDeps, CallContext, ModelRetry, ResultData

# A function that always takes `ResultData` and returns `ResultData`,
Expand Down Expand Up @@ -83,42 +84,87 @@ class ResultSchema(Generic[ResultData]):
Similar to `Retriever` but for the final result of running an agent.
"""

tools: dict[str, ResultTool[ResultData]]
allow_text_result: bool

@classmethod
def build(cls, response_type: type[ResultData], name: str, description: str | None) -> Self | None:
    """Create the result schema for `response_type`.

    Returns `None` when `response_type` is exactly `str`, in which case no
    result tools are built.

    If `response_type` is a union containing `str`, the `str` member is
    removed before building tools and `allow_text_result` is set to `True`.

    When the (remaining) type is a union, one tool is registered per member
    under the name given by `union_tool_name(name, member)`; otherwise a
    single tool is registered under `name`.
    """
    if response_type is str:
        return None

    without_str = extract_str_from_union(response_type)
    if without_str:
        response_type = without_str.value
    allow_text_result = bool(without_str)

    def _make_tool(member: Any, tool_name: str, multiple: bool) -> ResultTool[ResultData]:
        # `description` is captured from the enclosing call and shared by every tool
        built = ResultTool.build(member, tool_name, description, multiple)  # pyright: ignore[reportUnknownMemberType]
        return cast(ResultTool[ResultData], built)

    tools: dict[str, ResultTool[ResultData]] = {}
    members = union_args(response_type)
    if members:
        for member in members:
            member_tool_name = union_tool_name(name, member)
            tools[member_tool_name] = _make_tool(member, member_tool_name, True)
    else:
        tools[name] = _make_tool(response_type, name, False)

    return cls(tools=tools, allow_text_result=allow_text_result)

def find_tool(self, message: LLMToolCalls) -> tuple[ToolCall, ResultTool[ResultData]] | None:
    """Return the first call in `message.calls` that names a registered result tool.

    The return value is a `(call, tool)` pair, or `None` when no call's
    `tool_name` has a (truthy) entry in `self.tools`.
    """
    for candidate in message.calls:
        tool = self.tools.get(candidate.tool_name)
        if tool:
            return candidate, tool
    return None


DEFAULT_DESCRIPTION = 'The final response which ends this conversation'


@dataclass
class ResultTool(Generic[ResultData]):
name: str
description: str
type_adapter: TypeAdapter[Any]
json_schema: _utils.ObjectJsonSchema
allow_text_result: bool
outer_typed_dict_key: str | None

@classmethod
def build(cls, response_type: type[ResultData], name: str, description: str) -> Self | None:
"""Build a ResultSchema dataclass from a response type."""
if response_type is str:
return None
def build(cls, response_type: type[ResultData], name: str, description: str | None, multiple: bool) -> Self | None:
"""Build a ResultTool dataclass from a response type."""
assert response_type is not str, 'ResultTool does not support str as a response type'

allow_text_result = False
if _utils.is_model_like(response_type):
type_adapter = TypeAdapter(response_type)
outer_typed_dict_key: str | None = None
json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
else:
if response_type_option := _utils.extract_str_from_union(response_type):
response_type = response_type_option.value
allow_text_result = True

response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type}) # noqa
type_adapter = TypeAdapter(response_data_typed_dict)
outer_typed_dict_key = 'response'
json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
# including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
json_schema.pop('title') # pyright: ignore[reportCallIssue,reportArgumentType]
json_schema.pop('title')

if json_schema_description := json_schema.pop('description', None):
if description is None:
tool_description = json_schema_description
else:
tool_description = f'{description}. {json_schema_description}'
else:
tool_description = description or DEFAULT_DESCRIPTION
if multiple:
tool_description = f'{union_arg_name(response_type)}: {tool_description}'

return cls(
name=name,
description=description,
description=tool_description,
type_adapter=type_adapter,
json_schema=json_schema,
allow_text_result=allow_text_result,
outer_typed_dict_key=outer_typed_dict_key,
)

Expand All @@ -144,3 +190,36 @@ def validate(self, tool_call: messages.ToolCall) -> ResultData:
if k := self.outer_typed_dict_key:
result = result[k]
return result


def union_tool_name(base_name: str, union_arg: Any) -> str:
    """Return the tool name for one union member: `<base_name>_<member name>`.

    The member name comes from `union_arg_name`.
    """
    member_name = union_arg_name(union_arg)
    return '{}_{}'.format(base_name, member_name)


def union_arg_name(union_arg: Any) -> str:
    """Return a short name for a union member, used in tool names and descriptions.

    The member's `__name__` is returned whenever it exists, so classes and
    builtin types give the same result as before.

    Some valid union members have no `__name__` (for example `typing.List[int]`
    on Python 3.9). For those the name is derived from `str(union_arg)`, with
    every run of non-word characters replaced by a single underscore, rather
    than raising `AttributeError`.
    """
    name = getattr(union_arg, '__name__', None)
    if isinstance(name, str) and name:
        return name

    # local import: only needed on the fallback path
    import re

    sanitized = re.sub(r'\W+', '_', str(union_arg)).strip('_')
    # last resort if the string form contained no word characters at all
    return sanitized or type(union_arg).__name__


def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
    """Strip `str` out of a union type.

    If `response_type` is a union that has `str` among its members, return a
    `_utils.Some` wrapping what is left: the single remaining member, or a
    `Union` of the remaining members when there is more than one.

    For anything else (not a union, or a union without `str`) return `None`.
    """
    if not _pydantic.is_union(response_type):
        return None

    members = get_args(response_type)
    non_str_members: list[Any] = [member for member in members if member is not str]
    if len(non_str_members) == len(members):
        # nothing was filtered out, so `str` is not part of this union
        return None

    if len(non_str_members) == 1:
        return _utils.Some(non_str_members[0])
    return _utils.Some(Union[tuple(non_str_members)])


def union_args(response_type: Any) -> tuple[Any, ...]:
    """Return the members of `response_type` if it is a union, otherwise an empty tuple."""
    return get_args(response_type) if _pydantic.is_union(response_type) else ()
Loading

0 comments on commit 0042877

Please sign in to comment.