Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds dynamic to system_prompt decorator, allowing reevaluation #560

Merged
merged 32 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
c248ade
Adds dynamic to system_prompt decorator, allowing reevaluation
josead Dec 28, 2024
71c8808
Update pydantic_ai_slim/pydantic_ai/agent.py
josead Dec 30, 2024
f63ede3
lint fix
josead Dec 30, 2024
11b5abb
removing useless overload
josead Dec 30, 2024
38c5f62
Adds tests better naming
josead Dec 30, 2024
f2b793f
lint
josead Dec 30, 2024
b56bb1c
removes unused
josead Dec 30, 2024
0453499
Adds dynamic id to parts, referencing the runner that created it.
josead Jan 1, 2025
ade332e
Merge branch 'main' into main
josead Jan 1, 2025
2e63819
fix tests
josead Jan 1, 2025
49bf36f
fix lint
josead Jan 1, 2025
e8f798e
Fix merge
josead Jan 1, 2025
01ab9f4
Fixing overload
josead Jan 2, 2025
20cd93a
Update pydantic_ai_slim/pydantic_ai/messages.py
josead Jan 2, 2025
59f0ef4
Update pydantic_ai_slim/pydantic_ai/messages.py
josead Jan 2, 2025
2c84e65
Update pydantic_ai_slim/pydantic_ai/agent.py
josead Jan 2, 2025
260a610
Removes unused system prompts
josead Jan 2, 2025
2a73e13
Update pydantic_ai_slim/pydantic_ai/agent.py
josead Jan 2, 2025
78db2c2
Adds changes to use qual name for ref
josead Jan 2, 2025
3c1cc83
lint changes, make
josead Jan 2, 2025
08f041d
Adds default value and fix tests
josead Jan 2, 2025
5a10407
Update pydantic_ai_slim/pydantic_ai/agent.py
josead Jan 2, 2025
440b443
Update pydantic_ai_slim/pydantic_ai/agent.py
josead Jan 2, 2025
3d14344
Update pydantic_ai_slim/pydantic_ai/agent.py
josead Jan 2, 2025
e69a351
adds callable and dynamic assert on func none
josead Jan 2, 2025
feda9ff
adds the assert to the correct part
josead Jan 2, 2025
9da26c5
Adds Callable as return type
josead Jan 2, 2025
68380b8
Adds dynamic ref in system prompt
josead Jan 2, 2025
740ea26
Fix examples tests
josead Jan 2, 2025
956f143
Changes on docs to support dynamic
josead Jan 2, 2025
f15ba42
Adds correct overload
josead Jan 5, 2025
90a6069
tweaks to docs
samuelcolvin Jan 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/_system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
@dataclass
class SystemPromptRunner(Generic[AgentDeps]):
function: SystemPromptFunc[AgentDeps]
dynamic: bool = False
_takes_ctx: bool = field(init=False)
_is_async: bool = field(init=False)

Expand Down
67 changes: 54 additions & 13 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,16 +526,37 @@ def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
@overload
def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...

@overload
def system_prompt(
self, func: _system_prompt.SystemPromptFunc[AgentDeps] | None = None, /, *, dynamic: bool = False
josead marked this conversation as resolved.
Show resolved Hide resolved
) -> Any:

if func is None:
def decorator(
func_: _system_prompt.SystemPromptFunc[AgentDeps],
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func_, dynamic=dynamic))
return func_
return decorator
else:
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
return func

def system_prompt(
self, func: _system_prompt.SystemPromptFunc[AgentDeps], /
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
self,
func: _system_prompt.SystemPromptFunc[AgentDeps] | None = None,
/,
*,
dynamic: bool = False,
) -> Any:
"""Decorator to register a system prompt function.

Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
Can decorate a sync or async functions.

Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
the type of the function, see `tests/typed_agent.py` for tests.
Args:
func: The function to decorate
dynamic: If True, the system prompt will be reevaluated when messages_history is present.
josead marked this conversation as resolved.
Show resolved Hide resolved

Example:
```python
Expand All @@ -547,17 +568,21 @@ def system_prompt(
def simple_system_prompt() -> str:
return 'foobar'

@agent.system_prompt
@agent.system_prompt(dynamic=True)
async def async_system_prompt(ctx: RunContext[str]) -> str:
return f'{ctx.deps} is the best'

result = agent.run_sync('foobar', deps='spam')
print(result.data)
#> success (no tool calls)
```
"""
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
return func
if func is None:
def decorator(
func_: _system_prompt.SystemPromptFunc[AgentDeps],
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func_, dynamic=dynamic))
return func_
return decorator
else:
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
return func

@overload
def result_validator(
Expand Down Expand Up @@ -830,8 +855,8 @@ async def add_tool(tool: Tool[AgentDeps]) -> None:
)

async def _prepare_messages(
self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
) -> list[_messages.ModelMessage]:
self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
josead marked this conversation as resolved.
Show resolved Hide resolved
) -> list[_messages.ModelMessage]:
try:
messages = _messages_ctx_var.get()
except LookupError:
Expand All @@ -846,6 +871,22 @@ async def _prepare_messages(
if message_history:
# shallow copy messages
messages.extend(message_history)

# If there are any dynamic system prompts, we need to reevaluate them
if any(runner.dynamic for runner in self._system_prompt_functions):
# Get fresh system prompts
new_sys_parts = await self._sys_parts(run_context)

# Replace the system prompts in the existing messages
for msg in messages:
if isinstance(msg, _messages.ModelRequest):
# Keep non-system parts and add new system parts
non_system_parts = [
part for part in msg.parts
if not isinstance(part, _messages.SystemPromptPart)
]
msg.parts = new_sys_parts + non_system_parts

messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
else:
parts = await self._sys_parts(run_context)
Expand Down
230 changes: 230 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,3 +1230,233 @@ def test_double_capture_run_messages(set_event_loop: None) -> None:
ModelResponse(parts=[TextPart(content='success (no tool calls)')], timestamp=IsNow(tz=timezone.utc)),
]
)


def test_messages_history_reevaluate_system_prompt_clean(set_event_loop: None):
agent = Agent('test', system_prompt='Foobar')

dynamic_value = "A"

@agent.system_prompt
async def func(ctx):
return dynamic_value

res = agent.run_sync('Hello')

assert res.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(
content='Foobar',
part_kind='system-prompt'),
SystemPromptPart(
content=dynamic_value,
part_kind='system-prompt'),
UserPromptPart(
content='Hello',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response')
]
)

dynamic_value = "B"

@agent.system_prompt
async def func_two(ctx):
return dynamic_value + "!"

res_two = agent.run_sync('World', message_history=res.all_messages())

assert res_two.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(
content='Foobar',
part_kind='system-prompt'),
SystemPromptPart(
content="A", #Remains the same
part_kind='system-prompt'),
UserPromptPart(
content='Hello',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response'),
ModelRequest(
parts=[
UserPromptPart(
content='World',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response')
]
)

def test_messages_history_reevaluate_system_prompt_dirty(set_event_loop: None):
josead marked this conversation as resolved.
Show resolved Hide resolved
agent = Agent('test', system_prompt='Foobar')

dynamic_value = "A"

@agent.system_prompt(dynamic=True)
async def func(ctx):
return dynamic_value

res = agent.run_sync('Hello')

assert res.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(
content='Foobar',
part_kind='system-prompt'),
SystemPromptPart(
content=dynamic_value,
part_kind='system-prompt'),
UserPromptPart(
content='Hello',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response')
]
)

dynamic_value = "B"

@agent.system_prompt
async def func_two(ctx):
return "This is a new prompt, but it wont reach the model"

res_two = agent.run_sync('World', message_history=res.all_messages())

assert res_two.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(
content='Foobar',
part_kind='system-prompt'),
SystemPromptPart(
content="B", # Updated value since dirty
part_kind='system-prompt'),
UserPromptPart(
content='Hello',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response'),
ModelRequest(
parts=[
UserPromptPart(
content='World',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response')
]
)



def test_messages_history_reevaluate_system_prompt_dirty(set_event_loop: None):
agent = Agent('test', system_prompt='Foobar')

dynamic_value = "A"

@agent.system_prompt(dynamic=True)
async def func(ctx):
return dynamic_value

res = agent.run_sync('Hello')

assert res.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(
content='Foobar',
part_kind='system-prompt'),
SystemPromptPart(
content=dynamic_value,
part_kind='system-prompt'),
UserPromptPart(
content='Hello',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response')
]
)

dynamic_value = "B"

@agent.system_prompt(dynamic=True)
async def func_two(ctx):
return "This is a new prompt, and model will know"

res_two = agent.run_sync('World', message_history=res.all_messages())

assert res_two.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(
content='Foobar',
part_kind='system-prompt'),
SystemPromptPart(
content="B", # Updated value since dirty
part_kind='system-prompt'),
SystemPromptPart(
content="This is a new prompt, and model will know",
part_kind='system-prompt'),
UserPromptPart(
content='Hello',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response'),
ModelRequest(
parts=[
UserPromptPart(
content='World',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response')
]
)


Loading