From 7ccd5d2abf3ae74cb003c33ede135d04ce3b70c1 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 12 Oct 2024 22:28:35 +0100 Subject: [PATCH 1/2] auto generate retriever args for testing --- Makefile | 1 + demos/parse_model.py | 4 +- pydantic_ai/_utils.py | 15 +- pydantic_ai/agent.py | 10 +- pydantic_ai/models/__init__.py | 16 ++- pydantic_ai/models/function.py | 24 ++-- pydantic_ai/models/openai.py | 14 +- pydantic_ai/models/test.py | 247 +++++++++++++++++++++++++++++++++ tests/test_function_model.py | 53 +++++-- tests/test_testing.py | 76 ++++++++++ 10 files changed, 411 insertions(+), 49 deletions(-) create mode 100644 pydantic_ai/models/test.py create mode 100644 tests/test_testing.py diff --git a/Makefile b/Makefile index b05500c5..835ac7db 100644 --- a/Makefile +++ b/Makefile @@ -43,6 +43,7 @@ test: testcov: test @echo "building coverage html" @uv run coverage html --show-contexts + @uv run coverage report .PHONY: all all: format lint typecheck test diff --git a/demos/parse_model.py b/demos/parse_model.py index e45bbe19..2cd88de7 100644 --- a/demos/parse_model.py +++ b/demos/parse_model.py @@ -1,4 +1,3 @@ -from devtools import debug from pydantic import BaseModel from pydantic_ai import Agent @@ -11,7 +10,6 @@ class MyModel(BaseModel): agent = Agent('openai:gpt-4o', response_type=MyModel, deps=None) -# debug(agent.result_schema.json_schema) result = agent.run_sync('The windy city in the US of A.') -debug(result.response) +print(result.response) diff --git a/pydantic_ai/_utils.py b/pydantic_ai/_utils.py index a2a0a9dc..3a2b5402 100644 --- a/pydantic_ai/_utils.py +++ b/pydantic_ai/_utils.py @@ -53,11 +53,16 @@ def is_model_like(type_: Any) -> bool: ) -class ObjectJsonSchema(TypedDict): - type: Literal['object'] - title: str - properties: dict[str, JsonSchemaValue] - required: list[str] +ObjectJsonSchema = TypedDict( + 'ObjectJsonSchema', + { + 'type': Literal['object'], + 'title': str, + 'properties': dict[str, JsonSchemaValue], + 'required': list[str], + '$defs': dict[str, Any], + }, +) def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema: diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index d4cf33c0..78bee3e7 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -21,7 +21,7 @@ class Agent(Generic[AgentDeps, ResultData]): __slots__ = ( '_model', 'result_schema', - '_allow_plain_message', + '_allow_plain_response', '_system_prompts', '_retrievers', '_default_retries', @@ -51,7 +51,7 @@ def __init__( response_schema_description, response_retries if response_retries is not None else retries, ) - self._allow_plain_message = self.result_schema is None or self.result_schema.allow_plain_message + self._allow_plain_response = self.result_schema is None or self.result_schema.allow_plain_message self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) self._retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = {r_.name: r_ for r_ in retrievers} @@ -98,10 +98,10 @@ async def run( messages.append(_messages.UserPrompt(user_prompt)) - functions: list[_models.AbstractRetrieverDefinition] = list(self._retrievers.values()) + functions: list[_models.AbstractToolDefinition] = list(self._retrievers.values()) if self.result_schema is not None: functions.append(self.result_schema) - agent_model = model_.agent_model(self._allow_plain_message, functions) + agent_model = model_.agent_model(self._allow_plain_response, functions) for retriever in self._retrievers.values(): retriever.reset() @@ -217,7 +217,7 @@ async def 
_handle_model_response( messages.append(llm_message) if llm_message.role == 'llm-response': # plain string response - if self._allow_plain_message: + if self._allow_plain_response: return _utils.Some(cast(ResultData, llm_message.content)) else: messages.append(_messages.PlainResponseForbidden()) diff --git a/pydantic_ai/models/__init__.py b/pydantic_ai/models/__init__.py index cad03d29..b07206d9 100644 --- a/pydantic_ai/models/__init__.py +++ b/pydantic_ai/models/__init__.py @@ -19,8 +19,13 @@ class Model(ABC): """Abstract class for a model.""" @abstractmethod - def agent_model(self, allow_plain_message: bool, retrievers: list[AbstractRetrieverDefinition]) -> AgentModel: - """Create an agent model.""" + def agent_model(self, allow_plain_response: bool, tools: list[AbstractToolDefinition]) -> AgentModel: + """Create an agent model. + + Args: + allow_plain_response: Whether plain text final response is permitted. + tools: The tools available to the agent. + """ raise NotImplementedError() @@ -47,8 +52,11 @@ def infer_model(model: Model | KnownModelName) -> Model: raise TypeError(f'Invalid model: {model}') -class AbstractRetrieverDefinition(Protocol): - """Abstract definition of a retriever/function/tool.""" +class AbstractToolDefinition(Protocol): + """Abstract definition of a function/tool. + + These are generally retrievers, but can also include the response function if one exists. + """ name: str description: str diff --git a/pydantic_ai/models/function.py b/pydantic_ai/models/function.py index 05a5305f..ed84c554 100644 --- a/pydantic_ai/models/function.py +++ b/pydantic_ai/models/function.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Protocol from ..messages import LLMMessage, Message -from . import AbstractRetrieverDefinition, AgentModel, Model +from . import AbstractToolDefinition, AgentModel, Model if TYPE_CHECKING: from .._utils import ObjectJsonSchema @@ -12,12 +12,12 @@ class FunctionDef(Protocol): def __call__( - self, messages: list[Message], allow_plain_message: bool, retrievers: dict[str, RetrieverDescription], / + self, messages: list[Message], allow_plain_response: bool, tools: dict[str, ToolDescription], / ) -> LLMMessage: ... 
 @dataclass
-class RetrieverDescription:
+class ToolDescription:
     name: str
     description: str
     json_schema: ObjectJsonSchema
@@ -30,19 +30,21 @@ class FunctionModel(Model):
     function: FunctionDef
 
-    def agent_model(self, allow_plain_message: bool, retrievers: list[AbstractRetrieverDefinition]) -> AgentModel:
-        return TestAgentModel(
+    def agent_model(self, allow_plain_response: bool, tools: list[AbstractToolDefinition]) -> AgentModel:
+        return FunctionAgentModel(
             self.function,
-            allow_plain_message,
-            {r.name: RetrieverDescription(r.name, r.description, r.json_schema) for r in retrievers},
+            allow_plain_response,
+            {r.name: ToolDescription(r.name, r.description, r.json_schema) for r in tools},
         )
 
 
 @dataclass
-class TestAgentModel(AgentModel):
+class FunctionAgentModel(AgentModel):
+    __test__ = False
+
     function: FunctionDef
-    allow_plain_message: bool
-    retrievers: dict[str, RetrieverDescription]
+    allow_plain_response: bool
+    tools: dict[str, ToolDescription]
 
     async def request(self, messages: list[Message]) -> LLMMessage:
-        return self.function(messages, self.allow_plain_message, self.retrievers)
 
diff --git a/pydantic_ai/models/openai.py b/pydantic_ai/models/openai.py
index 08c9ab84..1647d6cb 100644
--- a/pydantic_ai/models/openai.py
+++ b/pydantic_ai/models/openai.py
@@ -16,7 +16,7 @@
     LLMResponse,
     Message,
 )
-from . import AbstractRetrieverDefinition, AgentModel, Model
+from . import AbstractToolDefinition, AgentModel, Model
 
 
 class OpenAIModel(Model):
@@ -26,12 +26,12 @@ def __init__(self, model_name: ChatModel, *, api_key: str | None = None, client:
         self.model_name: ChatModel = model_name
         self.client = client or cached_async_client(api_key)
 
-    def agent_model(self, allow_plain_message: bool, retrievers: list[AbstractRetrieverDefinition]) -> AgentModel:
+    def agent_model(self, allow_plain_response: bool, tools: list[AbstractToolDefinition]) -> AgentModel:
         return OpenAIAgentModel(
             self.client,
             self.model_name,
-            allow_plain_message,
-            [map_retriever_definition(t) for t in retrievers],
+            allow_plain_response,
+            [map_tool_definition(t) for t in tools],
         )
 
 
@@ -39,7 +39,7 @@ def agent_model(self, allow_plain_message: bool, retrievers: list[AbstractRetrie
 class OpenAIAgentModel(AgentModel):
     client: AsyncClient
     model_name: ChatModel
-    allow_plain_message: bool
+    allow_plain_response: bool
     tools: list[ChatCompletionToolParam]
 
     async def request(self, messages: list[Message]) -> LLMMessage:
@@ -66,7 +66,7 @@ async def completions_create(self, messages: list[Message]) -> ChatCompletion:
         # standalone function to make it easier to override
         if not self.tools:
             tool_choice: Literal['none', 'required', 'auto'] = 'none'
-        elif not self.allow_plain_message:
+        elif not self.allow_plain_response:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'
@@ -87,7 +87,7 @@ def cached_async_client(api_key: str) -> AsyncClient:
     return AsyncClient(api_key=api_key)
 
 
-def map_retriever_definition(f: AbstractRetrieverDefinition) -> ChatCompletionToolParam:
+def map_tool_definition(f: AbstractToolDefinition) -> ChatCompletionToolParam:
     return {
         'type': 'function',
         'function': {
diff --git a/pydantic_ai/models/test.py b/pydantic_ai/models/test.py
new file mode 100644
index 00000000..889bf7bc
--- /dev/null
+++ b/pydantic_ai/models/test.py
@@ -0,0 +1,247 @@
+"""
+Utilities for testing apps built with pydantic_ai, specifically by using a model which calls
+local functions.
+""" + +from __future__ import annotations as _annotations + +import json +import re +import string +from dataclasses import dataclass +from typing import Any, Literal + +from .. import _utils +from ..messages import FunctionCall, LLMFunctionCalls, LLMMessage, LLMResponse, Message +from . import AbstractToolDefinition, AgentModel, Model + + +@dataclass +class TestModel(Model): + """ + A model specifically for testing purposes. + + This will (by default) call all retrievers in the agent model, then return a tool response if possible, + otherwise a plain response. + + How useful this function will be is unknown, it may be useless, it may require significant changes to be useful. + """ + + # NOTE: Avoid test discovery by pytest. + __test__ = False + + call_retrievers: list[str] | Literal['all'] = 'all' + custom_response_text: str | None = None + custom_response_args: Any | None = None + + def agent_model(self, allow_plain_response: bool, tools: list[AbstractToolDefinition]) -> AgentModel: + if self.call_retrievers == 'all': + retriever_calls = [(r.name, gen_retriever_args(r)) for r in tools if r.name != 'response'] + else: + lookup = {r.name: r for r in tools} + retriever_calls = [(name, gen_retriever_args(lookup[name])) for name in self.call_retrievers] + + if self.custom_response_text is not None: + if not allow_plain_response: + raise ValueError('Plain response not allowed, but `custom_response_text` is set.') + final_response: _utils.Either[str, str] = _utils.Either(left=self.custom_response_text) + elif self.custom_response_args is not None: + response_def = next((r for r in tools if r.name == 'response'), None) + if response_def is None: + raise ValueError('Custom response arguments provided, but no response tool found.') + final_response = _utils.Either(right=self.custom_response_args) + else: + if response_def := next((r for r in tools if r.name == 'response'), None): + final_response = _utils.Either(right=gen_retriever_args(response_def)) + else: + final_response = _utils.Either(left='Final response') + return TestAgentModel(retriever_calls, final_response) + + +@dataclass +class TestAgentModel(AgentModel): + # NOTE: Avoid test discovery by pytest. + __test__ = False + + retriever_calls: list[tuple[str, str]] + # left means the final response is plain text, right means it's a function call + final_response: _utils.Either[str, str] + step: int = 0 + + async def request(self, messages: list[Message]) -> LLMMessage: + if self.step == 0: + self.step += 1 + return LLMFunctionCalls( + calls=[ + FunctionCall(function_id=name, function_name=name, arguments=args) + for name, args in self.retriever_calls + ] + ) + elif self.step == 1: + self.step += 1 + if response_text := self.final_response.left: + return LLMResponse(content=response_text) + else: + response_args = self.final_response.right + return LLMFunctionCalls( + calls=[FunctionCall(function_id='response', function_name='response', arguments=response_args)] + ) + else: + raise ValueError('Invalid step') + + +def gen_retriever_args(tool_def: AbstractToolDefinition) -> str: + """Generate arguments for a retriever.""" + return _JsonSchemaTestData(tool_def.json_schema).generate_json() + + +_chars = string.ascii_letters + string.digits + string.punctuation + + +class _JsonSchemaTestData: + """ + Generate data that matches a JSON schema. + + This tries to generate the minimal viable data for the schema. 
+ """ + + def __init__(self, schema: _utils.ObjectJsonSchema): + self.schema = schema + self.defs = schema.get('$defs', {}) + self.seed = 0 + + def generate(self) -> Any: + """Generate data for the JSON schema.""" + return self._gen_any(self.schema) # pyright: ignore[reportArgumentType] + + def generate_json(self) -> str: + return json.dumps(self.generate()) + + def _gen_any(self, schema: dict[str, Any]) -> Any: + """Generate data for any JSON Schema.""" + if const := schema.get('const'): + return const + elif enum := schema.get('enum'): + return enum[0] + elif examples := schema.get('examples'): + return examples[0] + + type_ = schema.get('type') + if type_ is None: + if ref := schema.get('$ref'): + key = re.sub(r'^#/\$defs/', '', ref) + js_def = self.defs[key] + return self._gen_any(js_def) + else: + # if there's no type or ref, we can't generate anything + return self._char() + + if type_ == 'object': + return self._object_gen(schema) + elif type_ == 'string': + return self._str_gen(schema) + elif type_ == 'integer': + return self._int_gen(schema) + elif type_ == 'number': + return float(self._int_gen(schema)) + elif type_ == 'boolean': + return self._bool_gen() + elif type_ == 'array': + return self._array_gen(schema) + else: + raise NotImplementedError(f'Unknown type: {type_}, please submit a PR to extend JsonSchemaTestData!') + + def _object_gen(self, schema: dict[str, Any]) -> dict[str, Any]: + """Generate data for a JSON Schema object.""" + required = set(schema.get('required', [])) + + data: dict[str, Any] = {} + if properties := schema.get('properties'): + for key, value in properties.items(): + if key in required: + data[key] = self._gen_any(value) + + if addition_props := schema.get('additionalProperties'): + add_prop_key = 'additionalProperty' + while add_prop_key in data: + add_prop_key += '_' + data[add_prop_key] = self._gen_any(addition_props) + + return data + + def _str_gen(self, schema: dict[str, Any]) -> str: + """Generate a string from a JSON Schema string.""" + min_len = schema.get('minLength') + if min_len is not None: + return self._char() * min_len + + if schema.get('maxLength') == 0: + return '' + else: + return self._char() + + def _int_gen(self, schema: dict[str, Any]) -> int: + """Generate an integer from a JSON Schema integer.""" + maximum = schema.get('maximum') + if maximum is None: + exc_max = schema.get('exclusiveMaximum') + if exc_max is not None: + maximum = exc_max - 1 + + minimum = schema.get('minimum') + if minimum is None: + exc_min = schema.get('exclusiveMinimum') + if exc_min is not None: + minimum = exc_min + 1 + + if minimum is not None and maximum is not None: + return minimum + self.seed % (maximum - minimum) + elif minimum is not None: + return minimum + self.seed + elif maximum is not None: + return maximum - self.seed + else: + return self.seed + + def _bool_gen(self) -> bool: + """Generate a boolean from a JSON Schema boolean.""" + return bool(self.seed % 2) + + def _array_gen(self, schema: dict[str, Any]) -> list[Any]: + """Generate an array from a JSON Schema array.""" + data: list[Any] = [] + unique_items = schema.get('uniqueItems') + if prefix_items := schema.get('prefixItems'): + for item in prefix_items: + if unique_items: + self.seed += 1 + data.append(self._gen_any(item)) + + items_schema = schema.get('items', {}) + min_items = schema.get('minItems', 0) + if min_items > len(data): + for _ in range(min_items - len(data)): + if unique_items: + self.seed += 1 + data.append(self._gen_any(items_schema)) + elif items_schema: + # if there 
is an `items` schema, add an item if minItems doesn't require it + # unless it would break `maxItems` rule + max_items = schema.get('maxItems') + if max_items is None or max_items > len(data): + if unique_items: + self.seed += 1 + data.append(self._gen_any(items_schema)) + + return data + + def _char(self) -> str: + """Generate a character on the same principle as Excel columns, e.g. a-z, aa-az...""" + chars = len(_chars) + s = '' + rem = self.seed // chars + while rem > 0: + s += _chars[rem % chars] + rem //= chars + s += _chars[self.seed % chars] + return s diff --git a/tests/test_function_model.py b/tests/test_function_model.py index d5330582..9adbb4be 100644 --- a/tests/test_function_model.py +++ b/tests/test_function_model.py @@ -16,9 +16,11 @@ LLMMessage, LLMResponse, Message, + SystemPrompt, UserPrompt, ) -from pydantic_ai.models.function import FunctionModel, RetrieverDescription +from pydantic_ai.models.function import FunctionModel, ToolDescription +from pydantic_ai.models.test import TestModel if TYPE_CHECKING: @@ -28,7 +30,7 @@ def IsNow(*args: Any, **kwargs: Any) -> datetime: ... def return_last( - messages: list[Message], _allow_plain_message: bool, _retrievers: dict[str, RetrieverDescription] + messages: list[Message], _allow_plain_message: bool, _retrievers: dict[str, ToolDescription] ) -> LLMMessage: last = messages[-1] response = asdict(last) @@ -85,7 +87,7 @@ def test_simple(): def whether_model( - messages: list[Message], allow_plain_message: bool, retrievers: dict[str, RetrieverDescription] + messages: list[Message], allow_plain_message: bool, retrievers: dict[str, ToolDescription] ) -> LLMMessage: assert allow_plain_message assert retrievers.keys() == {'get_location', 'get_whether'} @@ -192,7 +194,7 @@ def test_whether(): def call_function_model( - messages: list[Message], allow_plain_message: bool, retrievers: dict[str, RetrieverDescription] + messages: list[Message], allow_plain_message: bool, retrievers: dict[str, ToolDescription] ) -> LLMMessage: last = messages[-1] if last.role == 'user': @@ -237,7 +239,7 @@ def test_var_args(): def call_retriever( - messages: list[Message], _allow_plain_message: bool, retrievers: dict[str, RetrieverDescription] + messages: list[Message], _allow_plain_message: bool, retrievers: dict[str, ToolDescription] ) -> LLMMessage: if len(messages) == 1: assert len(retrievers) == 1 @@ -289,7 +291,7 @@ def get_check_foobar(ctx: CallContext[tuple[str, str]]) -> str: def test_result_schema_tuple(): - def return_tuple(_: list[Message], __: bool, retrievers: dict[str, RetrieverDescription]) -> LLMMessage: + def return_tuple(_: list[Message], __: bool, retrievers: dict[str, ToolDescription]) -> LLMMessage: assert len(retrievers) == 1 retriever_key = next(iter(retrievers.keys())) tuple_json = '{"response": ["foo", "bar"]}' @@ -308,7 +310,7 @@ class Foo(BaseModel): a: int b: str - def return_tuple(_: list[Message], __: bool, retrievers: dict[str, RetrieverDescription]) -> LLMMessage: + def return_tuple(_: list[Message], __: bool, retrievers: dict[str, ToolDescription]) -> LLMMessage: assert len(retrievers) == 1 retriever_key = next(iter(retrievers.keys())) tuple_json = '{"a": 1, "b": "foo"}' @@ -337,22 +339,22 @@ def test_model_arg(): @agent_all.retriever_context async def foo(_: CallContext[None], x: int) -> str: - return str(x * 2) + return str(x + 1) @agent_all.retriever_context(retries=3) def bar(_: CallContext[None], x: int) -> str: - return str(x * 3) + return str(x + 2) @agent_all.retriever_plain async def baz(x: int) -> str: - return str(x 
* 4) + return str(x + 3) @agent_all.retriever_plain(retries=1) def qux(x: int) -> str: - return str(x * 5) + return str(x + 4) @agent_all.system_prompt @@ -361,12 +363,35 @@ def spam() -> str: def test_register_all(): - def f( - messages: list[Message], allow_plain_message: bool, retrievers: dict[str, RetrieverDescription] - ) -> LLMMessage: + def f(messages: list[Message], allow_plain_message: bool, retrievers: dict[str, ToolDescription]) -> LLMMessage: return LLMResponse( f'messages={len(messages)} allow_plain_message={allow_plain_message} retrievers={len(retrievers)}' ) result = agent_all.run_sync('Hello', model=FunctionModel(f)) assert result.response == snapshot('messages=2 allow_plain_message=True retrievers=4') + + +def test_call_all(): + result = agent_all.run_sync('Hello', model=TestModel()) + assert result.response == snapshot('Final response') + assert result.message_history == snapshot( + [ + SystemPrompt(content='foobar'), + UserPrompt(content='Hello', timestamp=IsNow()), + LLMFunctionCalls( + calls=[ + FunctionCall(function_id='foo', function_name='foo', arguments='{"x": 0}'), + FunctionCall(function_id='bar', function_name='bar', arguments='{"x": 0}'), + FunctionCall(function_id='baz', function_name='baz', arguments='{"x": 0}'), + FunctionCall(function_id='qux', function_name='qux', arguments='{"x": 0}'), + ], + timestamp=IsNow(), + ), + FunctionReturn(function_id='foo', function_name='foo', content='1', timestamp=IsNow()), + FunctionReturn(function_id='bar', function_name='bar', content='2', timestamp=IsNow()), + FunctionReturn(function_id='baz', function_name='baz', content='3', timestamp=IsNow()), + FunctionReturn(function_id='qux', function_name='qux', content='4', timestamp=IsNow()), + LLMResponse(content='Final response', timestamp=IsNow()), + ] + ) diff --git a/tests/test_testing.py b/tests/test_testing.py new file mode 100644 index 00000000..1c436315 --- /dev/null +++ b/tests/test_testing.py @@ -0,0 +1,76 @@ +""" +This module contains tests for the testing module. +""" + +from __future__ import annotations as _annotations + +from typing import Annotated, Any, Literal + +from annotated_types import Gt, Lt, MaxLen, MinLen +from inline_snapshot import snapshot +from pydantic import BaseModel + +from pydantic_ai.models.test import _JsonSchemaTestData # pyright: ignore[reportPrivateUsage] + + +def test_simple(): + class NestedModel(BaseModel): + foo: str + bar: int + + class TestModel(BaseModel): + my_str: str + my_str_long: Annotated[str, MinLen(10)] + my_str_short: Annotated[str, MaxLen(1)] + my_int: int + my_int_gt: Annotated[int, Gt(5)] + my_int_lt: Annotated[int, Lt(-5)] + my_float: float + my_float_gt: Annotated[float, Gt(5.0)] + my_float_lt: Annotated[float, Lt(-5.0)] + my_bool: bool + my_bytes: bytes + my_fixed_tuple: tuple[int, str] + my_var_tuple: tuple[int, ...] 
+ my_list: list[str] + my_dict: dict[str, int] + my_set: set[str] + my_set_min_len: Annotated[set[str], MinLen(5)] + my_lit_int: Literal[1] + my_lit_ints: Literal[1, 2, 3] + my_lit_str: Literal['a'] + my_lit_strs: Literal['a', 'b', 'c'] + my_any: Any + nested: NestedModel + not_required: str = 'default' + + json_schema = TestModel.model_json_schema() + data = _JsonSchemaTestData(json_schema).generate() # pyright: ignore[reportArgumentType] + assert data == snapshot( + { + 'my_str': 'a', + 'my_str_long': 'aaaaaaaaaa', + 'my_str_short': 'a', + 'my_int': 0, + 'my_int_gt': 6, + 'my_int_lt': -6, + 'my_float': 0.0, + 'my_float_gt': 6.0, + 'my_float_lt': -6.0, + 'my_bool': False, + 'my_bytes': 'a', + 'my_fixed_tuple': [0, 'a'], + 'my_var_tuple': [0], + 'my_list': ['a'], + 'my_dict': {'additionalProperty': 0}, + 'my_set': ['b'], + 'my_set_min_len': ['c', 'd', 'e', 'f', 'g'], + 'my_lit_int': 1, + 'my_lit_ints': 1, + 'my_lit_str': 'a', + 'my_lit_strs': 'a', + 'my_any': 'g', + 'nested': {'foo': 'g', 'bar': 6}, + } + ) + TestModel.model_validate(data) From 3dd126af2168c114ca451f617c45a8f2967ec0a8 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 13 Oct 2024 00:08:07 +0100 Subject: [PATCH 2/2] improve coverage --- pydantic_ai/_pydantic.py | 11 +++++- pydantic_ai/agent.py | 2 +- pydantic_ai/result.py | 4 +-- tests/test_function_model.py | 67 ++++++++++++++++++++++++++++++------ 4 files changed, 69 insertions(+), 15 deletions(-) diff --git a/pydantic_ai/_pydantic.py b/pydantic_ai/_pydantic.py index 8f041b05..6efa3c5e 100644 --- a/pydantic_ai/_pydantic.py +++ b/pydantic_ai/_pydantic.py @@ -107,7 +107,11 @@ def function_schema(either_function: _r.RetrieverEitherFunc[_r.AgentDeps, _r.P]) field_info, decorators, ) - td_schema['metadata'] = {'is_model_like': is_model_like(annotation)} + extra_metadata = {'is_model_like': is_model_like(annotation)} + if metadata := td_schema.get('metadata'): + metadata.update(extra_metadata) + else: + td_schema['metadata'] = extra_metadata if p.kind == Parameter.POSITIONAL_ONLY: positional_fields.append(field_name) elif p.kind == Parameter.VAR_POSITIONAL: @@ -130,6 +134,11 @@ def function_schema(either_function: _r.RetrieverEitherFunc[_r.AgentDeps, _r.P]) # PluggableSchemaValidator is api compat with SchemaValidator schema_validator = cast(SchemaValidator, schema_validator) json_schema = GenerateJsonSchema().generate(schema) + + # instead of passing `description` through in core_schema, we just add it here + if description: + json_schema = {'description': description} | json_schema + return FunctionSchema( description=description, validator=schema_validator, diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index 78bee3e7..15b02598 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -51,7 +51,7 @@ def __init__( response_schema_description, response_retries if response_retries is not None else retries, ) - self._allow_plain_response = self.result_schema is None or self.result_schema.allow_plain_message + self._allow_plain_response = self.result_schema is None or self.result_schema.allow_plain_response self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) self._retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = {r_.name: r_ for r_ in retrievers} diff --git a/pydantic_ai/result.py b/pydantic_ai/result.py index 23f842df..1a7f4972 100644 --- a/pydantic_ai/result.py +++ b/pydantic_ai/result.py @@ -61,7 +61,7 @@ class ResultSchema(Generic[ResultData]): description: str type_adapter: TypeAdapter[Any] 
json_schema: _utils.ObjectJsonSchema - allow_plain_message: bool + allow_plain_response: bool outer_typed_dict: bool max_retries: int _current_retry: int = 0 @@ -86,7 +86,7 @@ def build(cls, response_type: type[ResultData], name: str, description: str, ret description=description, type_adapter=type_adapter, json_schema=_utils.check_object_json_schema(type_adapter.json_schema()), - allow_plain_message=_utils.allow_plain_str(response_type), + allow_plain_response=_utils.allow_plain_str(response_type), outer_typed_dict=outer_typed_dict, max_retries=retries, ) diff --git a/tests/test_function_model.py b/tests/test_function_model.py index 9adbb4be..abb724f0 100644 --- a/tests/test_function_model.py +++ b/tests/test_function_model.py @@ -30,7 +30,7 @@ def IsNow(*args: Any, **kwargs: Any) -> datetime: ... def return_last( - messages: list[Message], _allow_plain_message: bool, _retrievers: dict[str, ToolDescription] + messages: list[Message], _allow_plain_response: bool, _retrievers: dict[str, ToolDescription] ) -> LLMMessage: last = messages[-1] response = asdict(last) @@ -87,9 +87,9 @@ def test_simple(): def whether_model( - messages: list[Message], allow_plain_message: bool, retrievers: dict[str, ToolDescription] -) -> LLMMessage: - assert allow_plain_message + messages: list[Message], allow_plain_response: bool, retrievers: dict[str, ToolDescription] +) -> LLMMessage: # pragma: no cover + assert allow_plain_response assert retrievers.keys() == {'get_location', 'get_whether'} last = messages[-1] if last.role == 'user': @@ -194,8 +194,8 @@ def test_whether(): def call_function_model( - messages: list[Message], allow_plain_message: bool, retrievers: dict[str, ToolDescription] -) -> LLMMessage: + messages: list[Message], _allow_plain_response: bool, _tools: dict[str, ToolDescription] +) -> LLMMessage: # pragma: no cover last = messages[-1] if last.role == 'user': if last.content.startswith('{'): @@ -239,7 +239,7 @@ def test_var_args(): def call_retriever( - messages: list[Message], _allow_plain_message: bool, retrievers: dict[str, ToolDescription] + messages: list[Message], _allow_plain_response: bool, retrievers: dict[str, ToolDescription] ) -> LLMMessage: if len(messages) == 1: assert len(retrievers) == 1 @@ -343,7 +343,7 @@ async def foo(_: CallContext[None], x: int) -> str: @agent_all.retriever_context(retries=3) -def bar(_: CallContext[None], x: int) -> str: +def bar(ctx, x: int) -> str: # pyright: ignore[reportUnknownParameterType,reportMissingParameterType] return str(x + 2) @@ -357,19 +357,24 @@ def qux(x: int) -> str: return str(x + 4) +@agent_all.retriever_plain # pyright: ignore[reportUnknownArgumentType] +def quz(x) -> str: # pyright: ignore[reportUnknownParameterType,reportMissingParameterType] + return str(x) # pyright: ignore[reportUnknownArgumentType] + + @agent_all.system_prompt def spam() -> str: return 'foobar' def test_register_all(): - def f(messages: list[Message], allow_plain_message: bool, retrievers: dict[str, ToolDescription]) -> LLMMessage: + def f(messages: list[Message], allow_plain_response: bool, retrievers: dict[str, ToolDescription]) -> LLMMessage: return LLMResponse( - f'messages={len(messages)} allow_plain_message={allow_plain_message} retrievers={len(retrievers)}' + f'messages={len(messages)} allow_plain_response={allow_plain_response} retrievers={len(retrievers)}' ) result = agent_all.run_sync('Hello', model=FunctionModel(f)) - assert result.response == snapshot('messages=2 allow_plain_message=True retrievers=4') + assert result.response == snapshot('messages=2 
allow_plain_response=True retrievers=5') def test_call_all(): @@ -385,6 +390,7 @@ def test_call_all(): FunctionCall(function_id='bar', function_name='bar', arguments='{"x": 0}'), FunctionCall(function_id='baz', function_name='baz', arguments='{"x": 0}'), FunctionCall(function_id='qux', function_name='qux', arguments='{"x": 0}'), + FunctionCall(function_id='quz', function_name='quz', arguments='{"x": "a"}'), ], timestamp=IsNow(), ), @@ -392,6 +398,45 @@ def test_call_all(): FunctionReturn(function_id='bar', function_name='bar', content='2', timestamp=IsNow()), FunctionReturn(function_id='baz', function_name='baz', content='3', timestamp=IsNow()), FunctionReturn(function_id='qux', function_name='qux', content='4', timestamp=IsNow()), + FunctionReturn(function_id='quz', function_name='quz', content='a', timestamp=IsNow()), LLMResponse(content='Final response', timestamp=IsNow()), ] ) + + +async def do_foobar(foo: int, bar: str) -> str: + """ + Do foobar stuff, a lot. + + Args: + foo: The foo thing. + bar: The bar thing. + """ + return f'{foo} {bar}' + + +def test_docstring(): + def f(_messages: list[Message], _allow_plain_response: bool, retrievers: dict[str, ToolDescription]) -> LLMMessage: + assert len(retrievers) == 1 + r = next(iter(retrievers.values())) + return LLMResponse(json.dumps(r.json_schema)) + + agent = Agent(FunctionModel(f), deps=None) + agent.retriever_plain(do_foobar) + + result = agent.run_sync('Hello') + json_schema = json.loads(result.response) + assert json_schema == snapshot( + { + 'description': 'Do foobar stuff, a lot.', + 'additionalProperties': False, + 'properties': { + 'foo': {'description': 'The foo thing.', 'title': 'Foo', 'type': 'integer'}, + 'bar': {'description': 'The bar thing.', 'title': 'Bar', 'type': 'string'}, + }, + 'required': ['foo', 'bar'], + 'type': 'object', + } + ) + # description should be the first key + assert next(iter(json_schema)) == 'description'
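
A minimal usage sketch of the new TestModel, assembled from the APIs exercised in the tests above; the `double` retriever and the prompts are illustrative assumptions, while `Agent`, `retriever_plain`, `run_sync`, `custom_response_text` and the `'Final response'` default all come from this patch:

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

# Illustrative agent; the named model is never called because TestModel is
# passed to run_sync below.
agent = Agent('openai:gpt-4o', deps=None)


@agent.retriever_plain
def double(x: int) -> str:
    return str(x * 2)


# TestModel calls every registered retriever with arguments generated from its
# JSON schema (here `{"x": 0}`), then, since no result schema is registered,
# finishes with the plain-text 'Final response'.
result = agent.run_sync('Hello', model=TestModel())
assert result.response == 'Final response'

# custom_response_text overrides the generated final response.
result = agent.run_sync('Hello', model=TestModel(custom_response_text='done'))
assert result.response == 'done'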