remove PlainResponseForbidden (#12)
samuelcolvin authored Oct 20, 2024
1 parent 475fcaf commit ae2df48
Showing 5 changed files with 52 additions and 22 deletions.
7 changes: 6 additions & 1 deletion pydantic_ai/agent.py
@@ -248,7 +248,12 @@ async def _handle_model_response(
             if self._allow_text_result:
                 return _utils.Either(left=cast(result.ResultData, model_response.content))
             else:
-                return _utils.Either(right=[_messages.PlainResponseForbidden()])
+                self._incr_result_retry()
+                assert self._result_tool is not None
+                response = _messages.UserPrompt(
+                    content='Plain text responses are not permitted, please call one of the functions instead.',
+                )
+                return _utils.Either(right=[response])
         elif model_response.role == 'llm-tool-calls':
             if self._result_tool is not None:
                 # if there's a result schema, and any of the calls match that name, return the result
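The `Either` convention here drives the agent's retry loop: `left` carries final result data back to the caller, while `right` carries messages to append to the conversation for another round trip with the model (now the `UserPrompt` nudge, with `_incr_result_retry()` counting the attempt against the retry budget). A minimal sketch of such a two-sided type, for illustration only, not pydantic_ai's actual `_utils.Either`:

from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

L = TypeVar('L')
R = TypeVar('R')


@dataclass
class Either(Generic[L, R]):
    """Two-sided result: exactly one side is expected to be set.

    In _handle_model_response: left = final result data,
    right = messages to send back to the model for another attempt.
    """

    left: Optional[L] = None
    right: Optional[R] = None

    @property
    def is_left(self) -> bool:
        return self.left is not None


assert Either(left=('foo', 'bar')).is_left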
13 changes: 1 addition & 12 deletions pydantic_ai/messages.py
@@ -50,17 +50,6 @@ def llm_response(self) -> str:
         return f'{description}\n\nFix the errors and try again.'
 
 
-@dataclass
-class PlainResponseForbidden:
-    # TODO remove and replace with ToolRetry
-    timestamp: datetime = field(default_factory=datetime.now)
-    role: Literal['plain-response-forbidden'] = 'plain-response-forbidden'
-
-    @staticmethod
-    def llm_response() -> str:
-        return 'Plain text responses are not allowed, please call one of the functions instead.'
-
-
 @dataclass
 class LLMResponse:
     content: str
@@ -105,6 +94,6 @@ class LLMToolCalls:
 
 
 LLMMessage = Union[LLMResponse, LLMToolCalls]
-Message = Union[SystemPrompt, UserPrompt, ToolReturn, ToolRetry, PlainResponseForbidden, LLMMessage]
+Message = Union[SystemPrompt, UserPrompt, ToolReturn, ToolRetry, LLMMessage]
 
 MessagesTypeAdapter = pydantic.TypeAdapter(list[Annotated[Message, pydantic.Field(discriminator='role')]])
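Removing the class also removes its 'plain-response-forbidden' role from the discriminated union that `MessagesTypeAdapter` validates against. A sketch of how the `role` tag drives (de)serialization, assuming only the constructors visible elsewhere in this diff (field defaults are not shown in the hunk):

from pydantic_ai.messages import LLMResponse, MessagesTypeAdapter, UserPrompt

history = [UserPrompt(content='Hello'), LLMResponse(content='hello')]
raw = MessagesTypeAdapter.dump_json(history)       # each item is serialized with its `role` tag
restored = MessagesTypeAdapter.validate_json(raw)  # pydantic picks each class by that tag
assert restored == history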
2 changes: 0 additions & 2 deletions pydantic_ai/models/gemini.py
@@ -176,8 +176,6 @@ def message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiContent]:
     elif m.role == 'tool-retry':
         # ToolRetry ->
         return _utils.Either(right=_GeminiContent.function_retry(m))
-    elif m.role == 'plain-response-forbidden':
-        return _utils.Either(right=_GeminiContent.user_text(m.llm_response()))
     elif m.role == 'llm-response':
         # LLMResponse ->
         return _utils.Either(right=_GeminiContent.model_text(m.content))
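With the branch gone, the nudge reaches Gemini as an ordinary `UserPrompt`, handled by `message_to_gemini`'s existing user branch (outside this diff's context). The function's `Either` return separates bare text parts (`left`) from full content objects (`right`); a hedged sketch of a consumer, reusing the `Either` illustration above (the `_stub` names are hypothetical, not library code):

def collect_contents_stub(messages, message_to_gemini):
    # Split the mapped messages by which side of the Either they land on.
    text_parts, contents = [], []
    for m in messages:
        either = message_to_gemini(m)
        if either.is_left:
            text_parts.append(either.left)  # bare _GeminiTextPart
        else:
            contents.append(either.right)   # full _GeminiContent
    return text_parts, contents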
6 changes: 0 additions & 6 deletions pydantic_ai/models/openai.py
@@ -144,12 +144,6 @@ def map_message(message: Message) -> chat.ChatCompletionMessageParam:
             role='assistant',
             tool_calls=[_map_tool_call(t) for t in message.calls],
         )
-    elif message.role == 'plain-response-forbidden':
-        # PlainResponseForbidden ->
-        return chat.ChatCompletionUserMessageParam(
-            role='user',
-            content=message.llm_response(),
-        )
     else:
         assert_never(message)
 
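As with Gemini, the deleted OpenAI branch is now dead code: the nudge arrives as an ordinary `UserPrompt` and goes through `map_message`'s existing user branch, which sits outside this diff's context. That branch plausibly looks like the following, using the real `openai` param types; this is an inference, not code from the commit:

from openai.types import chat

from pydantic_ai.messages import UserPrompt


def map_user_prompt(message: UserPrompt) -> chat.ChatCompletionUserMessageParam:
    # Stand-alone rendering of the presumed 'user' branch of map_message.
    return chat.ChatCompletionUserMessageParam(role='user', content=message.content)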
46 changes: 45 additions & 1 deletion tests/test_result_validation.py → tests/test_agent.py
@@ -2,7 +2,16 @@
 from pydantic import BaseModel
 
 from pydantic_ai import Agent, ModelRetry
-from pydantic_ai.messages import LLMMessage, LLMToolCalls, Message, ToolCall, ToolRetry, UserPrompt
+from pydantic_ai.messages import (
+    ArgsJson,
+    LLMMessage,
+    LLMResponse,
+    LLMToolCalls,
+    Message,
+    ToolCall,
+    ToolRetry,
+    UserPrompt,
+)
 from pydantic_ai.models.function import AgentInfo, FunctionModel
 from tests.conftest import IsNow

@@ -107,3 +116,38 @@ def validate_result(r: Foo) -> Foo:
             LLMToolCalls(calls=[ToolCall.from_json('final_result', '{"a": 42, "b": "foo"}')], timestamp=IsNow()),
         ]
     )
+
+
+def test_plain_response():
+    call_index = 0
+
+    def return_tuple(_: list[Message], info: AgentInfo) -> LLMMessage:
+        nonlocal call_index
+
+        assert info.result_tool is not None
+        call_index += 1
+        if call_index == 1:
+            return LLMResponse(content='hello')
+        else:
+            args_json = '{"response": ["foo", "bar"]}'
+            return LLMToolCalls(calls=[ToolCall.from_json(info.result_tool.name, args_json)])
+
+    agent = Agent(FunctionModel(return_tuple), deps=None, result_type=tuple[str, str])
+
+    result = agent.run_sync('Hello')
+    assert result.response == ('foo', 'bar')
+    assert call_index == 2
+    assert result.message_history == snapshot(
+        [
+            UserPrompt(content='Hello', timestamp=IsNow()),
+            LLMResponse(content='hello', timestamp=IsNow()),
+            UserPrompt(
+                content='Plain text responses are not permitted, please call one of the functions instead.',
+                timestamp=IsNow(),
+            ),
+            LLMToolCalls(
+                calls=[ToolCall(tool_name='final_result', args=ArgsJson(args_json='{"response": ["foo", "bar"]}'))],
+                timestamp=IsNow(),
+            ),
+        ]
+    )
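The snapshot spells out what `ToolCall.from_json` produces: the raw JSON arguments are kept verbatim in an `ArgsJson` wrapper rather than parsed into a dict. A small check of that reading, inferred from the test above rather than from documented API:

from pydantic_ai.messages import ArgsJson, ToolCall

call = ToolCall.from_json('final_result', '{"response": ["foo", "bar"]}')
assert call == ToolCall(
    tool_name='final_result',
    args=ArgsJson(args_json='{"response": ["foo", "bar"]}'),
)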
