Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better deps typing #2

Merged
merged 4 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:

- run: make lint
- run: make typecheck
- run: make test-mypy
- run: make typecheck-mypy

test:
name: test on ${{ matrix.python-version }}
Expand Down Expand Up @@ -77,7 +77,7 @@ jobs:
merge-multiple: true
path: coverage

- run: pip install coverage[toml]
- run: pip install coverage[toml] --break-system-packages
- run: coverage combine coverage
- run: coverage xml
# - uses: codecov/codecov-action@v4
Expand Down
15 changes: 9 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,21 @@ lint:
uv run ruff format --check $(sources)
uv run ruff check $(sources)

.PHONY: typecheck # Run static type checking
typecheck:
.PHONY: typecheck-pyright
typecheck-pyright:
uv run pyright

.PHONY: typecheck-mypy
typecheck-mypy:
uv run mypy --strict tests/typed_agent.py

.PHONY: typecheck # Run static type checking
typecheck: typecheck-pyright

.PHONY: test # Run tests and collect coverage data
test:
uv run coverage run -m pytest

.PHONY: test-mypy # Run type tests with mypy
test-mypy:
uv run mypy --strict tests/typed_agent.py

.PHONY: testcov # Run tests and generate a coverage report
testcov: test
@echo "building coverage html"
Expand Down
2 changes: 1 addition & 1 deletion demos/parse_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class MyModel(BaseModel):
country: str


agent = Agent('openai:gpt-4o', response_type=MyModel)
agent = Agent('openai:gpt-4o', response_type=MyModel, deps=None)

# debug(agent.result_schema.json_schema)
result = agent.run_sync('The windy city in the US of A.')
Expand Down
2 changes: 1 addition & 1 deletion demos/weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic_ai import Agent

weather_agent = Agent('openai:gpt-4o')
weather_agent: Agent[None, str] = Agent('openai:gpt-4o')


@weather_agent.system_prompt
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSc
if field_info.description is None:
field_info.description = field_descriptions.get(field_name)

fields[field_name] = td_schema = gen_schema._generate_td_field_schema( # type: ignore[reportPrivateUsage]
fields[field_name] = td_schema = gen_schema._generate_td_field_schema( # pyright: ignore[reportPrivateUsage]
field_name,
field_info,
decorators,
Expand Down
16 changes: 8 additions & 8 deletions pydantic_ai/_system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,38 @@
from typing import Callable, Generic, Union

from . import _utils, retrievers as _r
from .retrievers import AgentDependencies
from .retrievers import AgentDeps

# This is basically a function that may or may not take `CallInfo` as an argument, and may or may not be async.
# Usage `SystemPromptFunc[AgentDependencies]`
SystemPromptFunc = Union[
Callable[[_r.CallContext[AgentDependencies]], str],
Callable[[_r.CallContext[AgentDependencies]], Awaitable[str]],
Callable[[_r.CallContext[AgentDeps]], str],
Callable[[_r.CallContext[AgentDeps]], Awaitable[str]],
Callable[[], str],
Callable[[], Awaitable[str]],
]


@dataclass
class SystemPromptRunner(Generic[AgentDependencies]):
function: SystemPromptFunc[AgentDependencies]
class SystemPromptRunner(Generic[AgentDeps]):
function: SystemPromptFunc[AgentDeps]
takes_ctx: bool = False
is_async: bool = False

def __post_init__(self):
self.takes_ctx = len(inspect.signature(self.function).parameters) > 0
self.is_async = inspect.iscoroutinefunction(self.function)

async def run(self, deps: AgentDependencies) -> str:
async def run(self, deps: AgentDeps) -> str:
if self.takes_ctx:
args = (_r.CallContext(deps, 0),)
else:
args = ()

if self.is_async:
return await self.function(*args) # type: ignore[reportGeneralTypeIssues]
return await self.function(*args) # pyright: ignore[reportGeneralTypeIssues,reportUnknownVariableType]
else:
return await _utils.run_in_executor(
self.function, # type: ignore[reportArgumentType]
self.function, # pyright: ignore[reportArgumentType,reportReturnType]
*args,
)
92 changes: 57 additions & 35 deletions pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,43 @@

import asyncio
from collections.abc import Awaitable, Sequence
from typing import Any, Callable, Generic, Literal, Union, cast, overload
from copy import copy
from typing import Any, Callable, Generic, Literal, cast, overload

from typing_extensions import assert_never

from . import _system_prompt, _utils, messages as _messages, models as _models, result as _result, retrievers as _r
from .result import ResultData
from .retrievers import AgentDependencies
from .retrievers import AgentDeps

__all__ = ('Agent',)


KnownModelName = Literal['openai:gpt-4o', 'openai:gpt-4-turbo', 'openai:gpt-4', 'openai:gpt-3.5-turbo']

SysPromptContext = Callable[[_r.CallContext[AgentDependencies]], Union[str, Awaitable[str]]]
SysPromptPlain = Callable[[], Union[str, Awaitable[str]]]


class Agent(Generic[ResultData, AgentDependencies]):
class Agent(Generic[AgentDeps, ResultData]):
"""Main class for creating "agents" - a way to have a specific type of "conversation" with an LLM."""

# slots mostly for my sanity — knowing what attributes are available
__slots__ = (
'_model',
'result_schema',
'_allow_plain_message',
'_system_prompts',
'_retrievers',
'_default_retries',
'_system_prompt_functions',
'_default_deps',
)

def __init__(
self,
model: _models.Model | KnownModelName | None = None,
response_type: type[_result.ResultData] = str,
*,
system_prompt: str | Sequence[str] = (),
retrievers: Sequence[_r.Retriever[AgentDependencies, Any]] = (),
deps: AgentDependencies = None,
retrievers: Sequence[_r.Retriever[AgentDeps, Any]] = (),
# type here looks odd, but it's required so you can avoid "partially unknown" type errors with `deps=None`
deps: AgentDeps | tuple[()] = (),
retries: int = 1,
response_schema_name: str = 'final_response',
response_schema_description: str = 'The final response',
Expand All @@ -46,26 +55,36 @@ def __init__(
self._allow_plain_message = self.result_schema is None or self.result_schema.allow_plain_message

self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
self._retrievers: dict[str, _r.Retriever[AgentDependencies, Any]] = {r_.name: r_ for r_ in retrievers}
self._retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = {r_.name: r_ for r_ in retrievers}
if self.result_schema and self.result_schema.name in self._retrievers:
raise ValueError(f'Retriever name conflicts with response schema: {self.result_schema.name!r}')
self._deps = deps
self._default_deps = cast(AgentDeps, None if deps == () else deps)
self._default_retries = retries
self._system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDependencies]] = []
self._system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = []

def with_deps(self, deps: AgentDeps) -> Agent[AgentDeps, ResultData]:
"""Return a new agent with the given dependencies."""
agent: Agent[AgentDeps, ResultData] = Agent.__new__(Agent)
for attr in self.__slots__:
setattr(agent, attr, copy(getattr(self, attr)))
agent._default_deps = deps
return agent

async def run(
self,
user_prompt: str,
*,
message_history: list[_messages.Message] | None = None,
model: _models.Model | KnownModelName | None = None,
deps: AgentDeps | None = None,
) -> _result.RunResult[_result.ResultData]:
"""Run the agent with a user prompt in async mode.

Args:
user_prompt: User input to start/continue the conversation.
message_history: History of the conversation so far.
model: Optional model to use for this run, required if `model` was not set when creating the agent.
deps: Optional dependencies to use for this run.

Returns:
The result of the run.
Expand All @@ -77,11 +96,14 @@ async def run(
else:
raise RuntimeError('`model` must be set either when creating the agent or when calling it.')

if deps is None:
deps = self._default_deps

if message_history is not None:
# shallow copy messages
messages = message_history.copy()
else:
messages = await self._init_messages()
messages = await self._init_messages(deps)

messages.append(_messages.UserPrompt(user_prompt))

Expand All @@ -95,7 +117,7 @@ async def run(

while True:
llm_message = await agent_model.request(messages)
opt_result = await self._handle_model_response(messages, llm_message)
opt_result = await self._handle_model_response(messages, llm_message, deps)
if opt_result is not None:
return _result.RunResult(opt_result.value, messages, cost=_result.Cost(0))

Expand All @@ -105,6 +127,7 @@ def run_sync(
*,
message_history: list[_messages.Message] | None = None,
model: _models.Model | KnownModelName | None = None,
deps: AgentDeps | None = None,
) -> _result.RunResult[_result.ResultData]:
"""Run the agent with a user prompt synchronously.

Expand All @@ -114,42 +137,41 @@ def run_sync(
user_prompt: User input to start/continue the conversation.
message_history: History of the conversation so far.
model: Optional model to use for this run, required if `model` was not set when creating the agent.
deps: Optional dependencies to use for this run.

Returns:
The result of the run.
"""
return asyncio.run(self.run(user_prompt, message_history=message_history, model=model))
return asyncio.run(self.run(user_prompt, message_history=message_history, model=model, deps=deps))

async def stream(self, user_prompt: str) -> _result.RunStreamResult[_result.ResultData]:
"""Run the agent with a user prompt asynchronously and stream the results."""
raise NotImplementedError()

def system_prompt(
self, func: _system_prompt.SystemPromptFunc[AgentDependencies]
) -> _system_prompt.SystemPromptFunc[AgentDependencies]:
self, func: _system_prompt.SystemPromptFunc[AgentDeps]
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
"""Decorator to register a system prompt function that takes `CallContext` as it's only argument."""
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
return func

@overload
def retriever_context(
self, func: _r.RetrieverContextFunc[AgentDependencies, _r.P], /
) -> _r.Retriever[AgentDependencies, _r.P]: ...
def retriever_context(self, func: _r.RetrieverContextFunc[AgentDeps, _r.P], /) -> _r.Retriever[AgentDeps, _r.P]: ...

@overload
def retriever_context(
self, /, *, retries: int | None = None
) -> Callable[[_r.RetrieverContextFunc[AgentDependencies, _r.P]], _r.Retriever[AgentDependencies, _r.P]]: ...
) -> Callable[[_r.RetrieverContextFunc[AgentDeps, _r.P]], _r.Retriever[AgentDeps, _r.P]]: ...

def retriever_context(
self, func: _r.RetrieverContextFunc[AgentDependencies, _r.P] | None = None, /, *, retries: int | None = None
self, func: _r.RetrieverContextFunc[AgentDeps, _r.P] | None = None, /, *, retries: int | None = None
) -> Any:
"""Decorator to register a retriever function."""
if func is None:

def retriever_decorator(
func_: _r.RetrieverContextFunc[AgentDependencies, _r.P],
) -> _r.Retriever[AgentDependencies, _r.P]:
func_: _r.RetrieverContextFunc[AgentDeps, _r.P],
) -> _r.Retriever[AgentDeps, _r.P]:
# noinspection PyTypeChecker
return self._register_retriever(func_, True, retries)

Expand All @@ -158,18 +180,18 @@ def retriever_decorator(
return self._register_retriever(func, True, retries)

@overload
def retriever_plain(self, func: _r.RetrieverPlainFunc[_r.P], /) -> _r.Retriever[AgentDependencies, _r.P]: ...
def retriever_plain(self, func: _r.RetrieverPlainFunc[_r.P], /) -> _r.Retriever[AgentDeps, _r.P]: ...

@overload
def retriever_plain(
self, /, *, retries: int | None = None
) -> Callable[[_r.RetrieverPlainFunc[_r.P]], _r.Retriever[AgentDependencies, _r.P]]: ...
) -> Callable[[_r.RetrieverPlainFunc[_r.P]], _r.Retriever[AgentDeps, _r.P]]: ...

def retriever_plain(self, func: _r.RetrieverPlainFunc[_r.P] | None = None, /, *, retries: int | None = None) -> Any:
"""Decorator to register a retriever function."""
if func is None:

def retriever_decorator(func_: _r.RetrieverPlainFunc[_r.P]) -> _r.Retriever[AgentDependencies, _r.P]:
def retriever_decorator(func_: _r.RetrieverPlainFunc[_r.P]) -> _r.Retriever[AgentDeps, _r.P]:
# noinspection PyTypeChecker
return self._register_retriever(func_, False, retries)

Expand All @@ -178,11 +200,11 @@ def retriever_decorator(func_: _r.RetrieverPlainFunc[_r.P]) -> _r.Retriever[Agen
return self._register_retriever(func, False, retries)

def _register_retriever(
self, func: _r.RetrieverEitherFunc[AgentDependencies, _r.P], takes_ctx: bool, retries: int | None
) -> _r.Retriever[AgentDependencies, _r.P]:
self, func: _r.RetrieverEitherFunc[AgentDeps, _r.P], takes_ctx: bool, retries: int | None
) -> _r.Retriever[AgentDeps, _r.P]:
"""Private utility to register a retriever function."""
retries_ = retries if retries is not None else self._default_retries
retriever = _r.Retriever[AgentDependencies, _r.P](func, takes_ctx, retries_)
retriever = _r.Retriever[AgentDeps, _r.P](func, takes_ctx, retries_)

if self.result_schema and self.result_schema.name == retriever.name:
raise ValueError(f'Retriever name conflicts with response schema name: {retriever.name!r}')
Expand All @@ -194,7 +216,7 @@ def _register_retriever(
return retriever

async def _handle_model_response(
self, messages: list[_messages.Message], llm_message: _messages.LLMMessage
self, messages: list[_messages.Message], llm_message: _messages.LLMMessage, deps: AgentDeps
) -> _utils.Option[ResultData]:
"""Process a single response from the model.

Expand Down Expand Up @@ -227,15 +249,15 @@ async def _handle_model_response(
if retriever is None:
# TODO return message?
raise ValueError(f'Unknown function name: {call.function_name!r}')
coros.append(retriever.run(self._deps, call))
coros.append(retriever.run(deps, call))
messages += await asyncio.gather(*coros)
else:
assert_never(llm_message)

async def _init_messages(self) -> list[_messages.Message]:
async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]:
"""Build the initial messages for the conversation."""
messages: list[_messages.Message] = [_messages.SystemPrompt(p) for p in self._system_prompts]
for sys_prompt_runner in self._system_prompt_functions:
prompt = await sys_prompt_runner.run(self._deps)
prompt = await sys_prompt_runner.run(deps)
messages.append(_messages.SystemPrompt(prompt))
return messages
2 changes: 1 addition & 1 deletion pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
elif model.startswith('openai:'):
from .openai import OpenAIModel

return OpenAIModel(model[7:]) # type: ignore[reportArgumentType]
return OpenAIModel(model[7:]) # pyright: ignore[reportArgumentType]
else:
raise TypeError(f'Invalid model: {model}')

Expand Down
Loading