Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing generate args #3

Merged
merged 2 commits into from
Oct 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ test:
testcov: test
@echo "building coverage html"
@uv run coverage html --show-contexts
@uv run coverage report

.PHONY: all
all: format lint typecheck test
4 changes: 1 addition & 3 deletions demos/parse_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from devtools import debug
from pydantic import BaseModel

from pydantic_ai import Agent
Expand All @@ -11,7 +10,6 @@ class MyModel(BaseModel):

agent = Agent('openai:gpt-4o', response_type=MyModel, deps=None)

# debug(agent.result_schema.json_schema)
result = agent.run_sync('The windy city in the US of A.')

debug(result.response)
print(result.response)
11 changes: 10 additions & 1 deletion pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ def function_schema(either_function: _r.RetrieverEitherFunc[_r.AgentDeps, _r.P])
field_info,
decorators,
)
td_schema['metadata'] = {'is_model_like': is_model_like(annotation)}
extra_metadata = {'is_model_like': is_model_like(annotation)}
if metadata := td_schema.get('metadata'):
metadata.update(extra_metadata)
else:
td_schema['metadata'] = extra_metadata
if p.kind == Parameter.POSITIONAL_ONLY:
positional_fields.append(field_name)
elif p.kind == Parameter.VAR_POSITIONAL:
Expand All @@ -130,6 +134,11 @@ def function_schema(either_function: _r.RetrieverEitherFunc[_r.AgentDeps, _r.P])
# PluggableSchemaValidator is api compat with SchemaValidator
schema_validator = cast(SchemaValidator, schema_validator)
json_schema = GenerateJsonSchema().generate(schema)

# instead of passing `description` through in core_schema, we just add it here
if description:
json_schema = {'description': description} | json_schema

return FunctionSchema(
description=description,
validator=schema_validator,
Expand Down
15 changes: 10 additions & 5 deletions pydantic_ai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,16 @@ def is_model_like(type_: Any) -> bool:
)


class ObjectJsonSchema(TypedDict):
type: Literal['object']
title: str
properties: dict[str, JsonSchemaValue]
required: list[str]
ObjectJsonSchema = TypedDict(
'ObjectJsonSchema',
{
'type': Literal['object'],
'title': str,
'properties': dict[str, JsonSchemaValue],
'required': list[str],
'$defs': dict[str, Any],
},
)


def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
Expand Down
10 changes: 5 additions & 5 deletions pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Agent(Generic[AgentDeps, ResultData]):
__slots__ = (
'_model',
'result_schema',
'_allow_plain_message',
'_allow_plain_response',
'_system_prompts',
'_retrievers',
'_default_retries',
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(
response_schema_description,
response_retries if response_retries is not None else retries,
)
self._allow_plain_message = self.result_schema is None or self.result_schema.allow_plain_message
self._allow_plain_response = self.result_schema is None or self.result_schema.allow_plain_response

self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
self._retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = {r_.name: r_ for r_ in retrievers}
Expand Down Expand Up @@ -98,10 +98,10 @@ async def run(

messages.append(_messages.UserPrompt(user_prompt))

functions: list[_models.AbstractRetrieverDefinition] = list(self._retrievers.values())
functions: list[_models.AbstractToolDefinition] = list(self._retrievers.values())
if self.result_schema is not None:
functions.append(self.result_schema)
agent_model = model_.agent_model(self._allow_plain_message, functions)
agent_model = model_.agent_model(self._allow_plain_response, functions)

for retriever in self._retrievers.values():
retriever.reset()
Expand Down Expand Up @@ -217,7 +217,7 @@ async def _handle_model_response(
messages.append(llm_message)
if llm_message.role == 'llm-response':
# plain string response
if self._allow_plain_message:
if self._allow_plain_response:
return _utils.Some(cast(ResultData, llm_message.content))
else:
messages.append(_messages.PlainResponseForbidden())
Expand Down
16 changes: 12 additions & 4 deletions pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ class Model(ABC):
"""Abstract class for a model."""

@abstractmethod
def agent_model(self, allow_plain_message: bool, retrievers: list[AbstractRetrieverDefinition]) -> AgentModel:
"""Create an agent model."""
def agent_model(self, allow_plain_response: bool, tools: list[AbstractToolDefinition]) -> AgentModel:
"""Create an agent model.

Args:
allow_plain_response: Whether plain text final response is permitted.
tools: The tools available to the agent.
"""
raise NotImplementedError()


Expand All @@ -47,8 +52,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
raise TypeError(f'Invalid model: {model}')


class AbstractRetrieverDefinition(Protocol):
"""Abstract definition of a retriever/function/tool."""
class AbstractToolDefinition(Protocol):
"""Abstract definition of a function/tool.

These are generally retrievers, but can also include the response function if one exists.
"""

name: str
description: str
Expand Down
24 changes: 13 additions & 11 deletions pydantic_ai/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
from typing import TYPE_CHECKING, Protocol

from ..messages import LLMMessage, Message
from . import AbstractRetrieverDefinition, AgentModel, Model
from . import AbstractToolDefinition, AgentModel, Model

if TYPE_CHECKING:
from .._utils import ObjectJsonSchema


class FunctionDef(Protocol):
def __call__(
self, messages: list[Message], allow_plain_message: bool, retrievers: dict[str, RetrieverDescription], /
self, messages: list[Message], allow_plain_response: bool, tools: dict[str, ToolDescription], /
) -> LLMMessage: ...


@dataclass
class RetrieverDescription:
class ToolDescription:
name: str
description: str
json_schema: ObjectJsonSchema
Expand All @@ -30,19 +30,21 @@ class FunctionModel(Model):

function: FunctionDef

def agent_model(self, allow_plain_message: bool, retrievers: list[AbstractRetrieverDefinition]) -> AgentModel:
return TestAgentModel(
def agent_model(self, allow_plain_response: bool, tools: list[AbstractToolDefinition]) -> AgentModel:
return FunctionAgentModel(
self.function,
allow_plain_message,
{r.name: RetrieverDescription(r.name, r.description, r.json_schema) for r in retrievers},
allow_plain_response,
{r.name: ToolDescription(r.name, r.description, r.json_schema) for r in tools},
)


@dataclass
class TestAgentModel(AgentModel):
class FunctionAgentModel(AgentModel):
__test__ = False

function: FunctionDef
allow_plain_message: bool
retrievers: dict[str, RetrieverDescription]
allow_plain_response: bool
tools: dict[str, ToolDescription]

async def request(self, messages: list[Message]) -> LLMMessage:
return self.function(messages, self.allow_plain_message, self.retrievers)
return self.function(messages, self.allow_plain_response, self.tools)
14 changes: 7 additions & 7 deletions pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
LLMResponse,
Message,
)
from . import AbstractRetrieverDefinition, AgentModel, Model
from . import AbstractToolDefinition, AgentModel, Model


class OpenAIModel(Model):
Expand All @@ -26,20 +26,20 @@ def __init__(self, model_name: ChatModel, *, api_key: str | None = None, client:
self.model_name: ChatModel = model_name
self.client = client or cached_async_client(api_key)

def agent_model(self, allow_plain_message: bool, retrievers: list[AbstractRetrieverDefinition]) -> AgentModel:
def agent_model(self, allow_plain_response: bool, tools: list[AbstractToolDefinition]) -> AgentModel:
return OpenAIAgentModel(
self.client,
self.model_name,
allow_plain_message,
[map_retriever_definition(t) for t in retrievers],
allow_plain_response,
[map_tool_definition(t) for t in tools],
)


@dataclass
class OpenAIAgentModel(AgentModel):
client: AsyncClient
model_name: ChatModel
allow_plain_message: bool
allow_plain_response: bool
tools: list[ChatCompletionToolParam]

async def request(self, messages: list[Message]) -> LLMMessage:
Expand All @@ -66,7 +66,7 @@ async def completions_create(self, messages: list[Message]) -> ChatCompletion:
# standalone function to make it easier to override
if not self.tools:
tool_choice: Literal['none', 'required', 'auto'] = 'none'
elif not self.allow_plain_message:
elif not self.allow_plain_response:
tool_choice = 'required'
else:
tool_choice = 'auto'
Expand All @@ -87,7 +87,7 @@ def cached_async_client(api_key: str) -> AsyncClient:
return AsyncClient(api_key=api_key)


def map_retriever_definition(f: AbstractRetrieverDefinition) -> ChatCompletionToolParam:
def map_tool_definition(f: AbstractToolDefinition) -> ChatCompletionToolParam:
return {
'type': 'function',
'function': {
Expand Down
Loading