Skip to content

Commit

Permalink
rename tools -> retrievers
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Oct 12, 2024
1 parent c5d5f24 commit ff834b5
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 35 deletions.
2 changes: 1 addition & 1 deletion pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ async def run(

messages.append(_messages.UserPrompt(user_prompt))

functions: list[_models.AbstractToolDefinition] = list(self._retrievers.values())
functions: list[_models.AbstractRetrieverDefinition] = list(self._retrievers.values())
if self.result_schema is not None:
functions.append(self.result_schema)
agent_model = model_.agent_model(self._allow_plain_message, functions)
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ class LLMResponse:

@dataclass
class FunctionCall:
"""
Either a retriever/tool call or a structured response from the agent.
"""

function_id: str
function_name: str
arguments: str
Expand Down
6 changes: 3 additions & 3 deletions pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Model(ABC):
"""Abstract class for a model."""

@abstractmethod
def agent_model(self, allow_plain_message: bool, tools: list[AbstractToolDefinition]) -> AgentModel:
def agent_model(self, allow_plain_message: bool, retrievers: list[AbstractRetrieverDefinition]) -> AgentModel:
"""Create an agent model."""
raise NotImplementedError()

Expand Down Expand Up @@ -47,8 +47,8 @@ def infer_model(model: Model | KnownModelName) -> Model:
raise TypeError(f'Invalid model: {model}')


class AbstractToolDefinition(Protocol):
"""Abstract definition of a function/tool."""
class AbstractRetrieverDefinition(Protocol):
"""Abstract definition of a retriever/function/tool."""

name: str
description: str
Expand Down
18 changes: 11 additions & 7 deletions pydantic_ai/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@
from typing import TYPE_CHECKING, Protocol

from ..messages import LLMMessage, Message
from . import AbstractToolDefinition, AgentModel, Model
from . import AbstractRetrieverDefinition, AgentModel, Model

if TYPE_CHECKING:
from .._utils import ObjectJsonSchema


class FunctionDef(Protocol):
def __call__(self, messages: list[Message], allow_plain_message: bool, tools: dict[str, Tool], /) -> LLMMessage: ...
def __call__(
self, messages: list[Message], allow_plain_message: bool, retrievers: dict[str, RetrieverDescription], /
) -> LLMMessage: ...


@dataclass
class Tool:
class RetrieverDescription:
name: str
description: str
json_schema: ObjectJsonSchema
Expand All @@ -28,17 +30,19 @@ class FunctionModel(Model):

function: FunctionDef

def agent_model(self, allow_plain_message: bool, tools: list[AbstractToolDefinition]) -> AgentModel:
def agent_model(self, allow_plain_message: bool, retrievers: list[AbstractRetrieverDefinition]) -> AgentModel:
return TestAgentModel(
self.function, allow_plain_message, {t.name: Tool(t.name, t.description, t.json_schema) for t in tools}
self.function,
allow_plain_message,
{r.name: RetrieverDescription(r.name, r.description, r.json_schema) for r in retrievers},
)


@dataclass
class TestAgentModel(AgentModel):
function: FunctionDef
allow_plain_message: bool
tools: dict[str, Tool]
retrievers: dict[str, RetrieverDescription]

async def request(self, messages: list[Message]) -> LLMMessage:
return self.function(messages, self.allow_plain_message, self.tools)
return self.function(messages, self.allow_plain_message, self.retrievers)
8 changes: 4 additions & 4 deletions pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
LLMResponse,
Message,
)
from . import AbstractToolDefinition, AgentModel, Model
from . import AbstractRetrieverDefinition, AgentModel, Model


class OpenAIModel(Model):
Expand All @@ -26,12 +26,12 @@ def __init__(self, model_name: ChatModel, *, api_key: str | None = None, client:
self.model_name: ChatModel = model_name
self.client = client or cached_async_client(api_key)

def agent_model(self, allow_plain_message: bool, tools: list[AbstractToolDefinition]) -> AgentModel:
def agent_model(self, allow_plain_message: bool, retrievers: list[AbstractRetrieverDefinition]) -> AgentModel:
return OpenAIAgentModel(
self.client,
self.model_name,
allow_plain_message,
[map_tool_definition(t) for t in tools],
[map_retriever_definition(t) for t in retrievers],
)


Expand Down Expand Up @@ -87,7 +87,7 @@ def cached_async_client(api_key: str) -> AsyncClient:
return AsyncClient(api_key=api_key)


def map_tool_definition(f: AbstractToolDefinition) -> ChatCompletionToolParam:
def map_retriever_definition(f: AbstractRetrieverDefinition) -> ChatCompletionToolParam:
return {
'type': 'function',
'function': {
Expand Down
56 changes: 36 additions & 20 deletions tests/test_function_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Message,
UserPrompt,
)
from pydantic_ai.models.function import FunctionModel, Tool
from pydantic_ai.models.function import FunctionModel, RetrieverDescription

if TYPE_CHECKING:

Expand All @@ -27,7 +27,9 @@ def IsNow(*args: Any, **kwargs: Any) -> datetime: ...
from dirty_equals import IsNow


def return_last(messages: list[Message], _allow_plain_message: bool, _tools: dict[str, Tool]) -> LLMMessage:
def return_last(
messages: list[Message], _allow_plain_message: bool, _retrievers: dict[str, RetrieverDescription]
) -> LLMMessage:
last = messages[-1]
response = asdict(last)
response.pop('timestamp', None)
Expand Down Expand Up @@ -82,9 +84,11 @@ def test_simple():
)


def whether_model(messages: list[Message], allow_plain_message: bool, tools: dict[str, Tool]) -> LLMMessage:
def whether_model(
messages: list[Message], allow_plain_message: bool, retrievers: dict[str, RetrieverDescription]
) -> LLMMessage:
assert allow_plain_message
assert tools.keys() == {'get_location', 'get_whether'}
assert retrievers.keys() == {'get_location', 'get_whether'}
last = messages[-1]
if last.role == 'user':
return LLMFunctionCalls(
Expand Down Expand Up @@ -187,7 +191,9 @@ def test_whether():
assert result.response == 'Sunny in Ipswich'


def call_function_model(messages: list[Message], allow_plain_message: bool, tools: dict[str, Tool]) -> LLMMessage:
def call_function_model(
messages: list[Message], allow_plain_message: bool, retrievers: dict[str, RetrieverDescription]
) -> LLMMessage:
last = messages[-1]
if last.role == 'user':
if last.content.startswith('{'):
Expand Down Expand Up @@ -230,11 +236,13 @@ def test_var_args():
)


def call_retriever(messages: list[Message], _allow_plain_message: bool, tools: dict[str, Tool]) -> LLMMessage:
def call_retriever(
messages: list[Message], _allow_plain_message: bool, retrievers: dict[str, RetrieverDescription]
) -> LLMMessage:
if len(messages) == 1:
assert len(tools) == 1
tool_id = next(iter(tools.keys()))
return LLMFunctionCalls(calls=[FunctionCall(function_id='1', function_name=tool_id, arguments='{}')])
assert len(retrievers) == 1
retriever_id = next(iter(retrievers.keys()))
return LLMFunctionCalls(calls=[FunctionCall(function_id='1', function_name=retriever_id, arguments='{}')])
else:
return LLMResponse('final response')

Expand Down Expand Up @@ -281,11 +289,13 @@ def get_check_foobar(ctx: CallContext[tuple[str, str]]) -> str:


def test_result_schema_tuple():
def return_tuple(_: list[Message], __: bool, tools: dict[str, Tool]) -> LLMMessage:
assert len(tools) == 1
tool_id = next(iter(tools.keys()))
def return_tuple(_: list[Message], __: bool, retrievers: dict[str, RetrieverDescription]) -> LLMMessage:
assert len(retrievers) == 1
retriever_key = next(iter(retrievers.keys()))
tuple_json = '{"response": ["foo", "bar"]}'
return LLMFunctionCalls(calls=[FunctionCall(function_id='1', function_name=tool_id, arguments=tuple_json)])
return LLMFunctionCalls(
calls=[FunctionCall(function_id='1', function_name=retriever_key, arguments=tuple_json)]
)

agent = Agent(FunctionModel(return_tuple), deps=None, response_type=tuple[str, str])

Expand All @@ -298,11 +308,13 @@ class Foo(BaseModel):
a: int
b: str

def return_tuple(_: list[Message], __: bool, tools: dict[str, Tool]) -> LLMMessage:
assert len(tools) == 1
tool_id = next(iter(tools.keys()))
def return_tuple(_: list[Message], __: bool, retrievers: dict[str, RetrieverDescription]) -> LLMMessage:
assert len(retrievers) == 1
retriever_key = next(iter(retrievers.keys()))
tuple_json = '{"a": 1, "b": "foo"}'
return LLMFunctionCalls(calls=[FunctionCall(function_id='1', function_name=tool_id, arguments=tuple_json)])
return LLMFunctionCalls(
calls=[FunctionCall(function_id='1', function_name=retriever_key, arguments=tuple_json)]
)

agent = Agent(FunctionModel(return_tuple), deps=None, response_type=Foo)

Expand Down Expand Up @@ -349,8 +361,12 @@ def spam() -> str:


def test_register_all():
def f(messages: list[Message], allow_plain_message: bool, tools: dict[str, Tool]) -> LLMMessage:
return LLMResponse(f'messages={len(messages)} allow_plain_message={allow_plain_message} tools={len(tools)}')
def f(
messages: list[Message], allow_plain_message: bool, retrievers: dict[str, RetrieverDescription]
) -> LLMMessage:
return LLMResponse(
f'messages={len(messages)} allow_plain_message={allow_plain_message} retrievers={len(retrievers)}'
)

result = agent_all.run_sync('Hello', model=FunctionModel(f))
assert result.response == snapshot('messages=2 allow_plain_message=True tools=4')
assert result.response == snapshot('messages=2 allow_plain_message=True retrievers=4')

0 comments on commit ff834b5

Please sign in to comment.