Use OpenAI's Structured Outputs feature to prevent validation errors #514
base: main
Changes from all commits: 300349e, 5f6eebc, 3c32dce, 70cfeec, c33836b, 0954e45, 3f81f1a, 8057da1
```diff
@@ -13,3 +13,4 @@ env*/
 /pydantic_ai_examples/.chat_app_messages.sqlite
 .cache/
 .vscode/
+*.ipynb
```
```diff
@@ -42,6 +42,7 @@
     from openai.types import ChatModel, chat
     from openai.types.chat import ChatCompletionChunk
     from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
+    from openai.types.shared_params.response_format_json_schema import ResponseFormatJSONSchema
 except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
```
```diff
@@ -111,14 +112,8 @@ async def agent_model(
     ) -> AgentModel:
         check_allow_model_requests()
         tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return OpenAIAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
-        )
+        response_format = self._map_response_format(result_tools[0]) if result_tools else None
+        return OpenAIAgentModel(self.client, self.model_name, tools, response_format)

     def name(self) -> str:
         return f'openai:{self.model_name}'
```
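For context, a minimal usage sketch of the path this hunk changes, assuming the `pydantic_ai` `Agent` API of this era (`result_type` keyword, `result.data` accessor): under this PR, the result schema would travel to OpenAI as a `response_format` rather than as an extra result tool.

```python
# Hypothetical usage sketch: with this PR, the schema of `CityInfo` would be
# sent to OpenAI as a `json_schema` response_format instead of a result tool.
from pydantic import BaseModel
from pydantic_ai import Agent

class CityInfo(BaseModel):
    city: str
    country: str

agent = Agent('openai:gpt-4o', result_type=CityInfo)
# result = agent.run_sync('Tell me about Paris')  # requires OPENAI_API_KEY
# print(result.data)  # -> CityInfo(city='Paris', country='France')
```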
```diff
@@ -134,15 +129,22 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
             },
         }

+    @staticmethod
+    def _map_response_format(f: ToolDefinition) -> ResponseFormatJSONSchema:
+        return {
+            'type': 'json_schema',
+            'json_schema': {'name': f.name, 'description': f.description, 'schema': f.parameters_json_schema},
+        }
+

 @dataclass
 class OpenAIAgentModel(AgentModel):
     """Implementation of `AgentModel` for OpenAI models."""

     client: AsyncOpenAI
     model_name: OpenAIModelName
-    allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
+    response_format: ResponseFormatJSONSchema | None

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
```
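To make the mapping concrete, here is an illustrative sketch of the payload `_map_response_format` would build for a typical result tool. The structure follows the diff above; the name `final_result` (pydantic-ai's default result tool name) and the schema values are illustrative assumptions.

```python
# Illustrative only: the response_format dict built from a result tool named
# 'final_result' whose parameters schema comes from a Pydantic model.
response_format = {
    'type': 'json_schema',
    'json_schema': {
        'name': 'final_result',
        'description': 'The final response',
        'schema': {
            'type': 'object',
            'properties': {'city': {'type': 'string'}, 'country': {'type': 'string'}},
            'required': ['city', 'country'],
        },
    },
}
```

Note that OpenAI's Structured Outputs also support a `strict: true` flag, which this mapping does not set; whether the schemas pydantic-ai generates satisfy strict mode's restrictions (e.g. `additionalProperties: false` everywhere) would be a further question for this approach.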
```diff
@@ -174,20 +176,15 @@ async def _completions_create(
         self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
-        if not self.tools:
-            tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not self.allow_text_result:
-            tool_choice = 'required'
-        else:
-            tool_choice = 'auto'
+        tool_choice = 'auto' if self.tools else None

         openai_messages = list(chain(*(self._map_message(m) for m in messages)))

         model_settings = model_settings or {}

-        return await self.client.chat.completions.create(
+        response = await self.client.chat.completions.create(
             model=self.model_name,
             messages=openai_messages,
+            response_format=self.response_format or NOT_GIVEN,
             n=1,
             parallel_tool_calls=True if self.tools else NOT_GIVEN,
             tools=self.tools or NOT_GIVEN,
```

> **Review comment** (on the `tool_choice` change): We have to respect `allow_text_result`.
>
> **Reply:** Yes, my changes ignore the `allow_text_result` parameter.

> **Review comment** (on `response = await self.client.chat.completions.create(`): You don't need this change.
>
> **Reply:** Yes, we really don't need that.
```diff
@@ -200,14 +197,19 @@ async def _completions_create(
             timeout=model_settings.get('timeout', NOT_GIVEN),
         )

-    @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
+        return response
+
+    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
         if choice.message.content is not None:
-            items.append(TextPart(choice.message.content))
+            if self.response_format:
+                name = self.response_format['json_schema']['name']
+                items.append(ToolCallPart.from_raw_args(name, choice.message.content))
+            else:
+                items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
```

> **Review comment** (on `ToolCallPart.from_raw_args(name, choice.message.content)`): I tried this locally; OpenAI requires `tool_call_id` to be set, which needs to be passed through this method, but `choice` does not really come with an id. Either we generate an id or use the choice index, perhaps?
>
> **Reply:** I ran into the same case recently. Using the […] At the moment, PydanticAI wraps structured responses in `ToolCallPart`s, which in my opinion needs to be changed to better differentiate between actual tool calls and structured responses. For solving this issue, we need to find a way for the agent to return the response instead of making a tool call from it.
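One possible resolution to the `tool_call_id` question above, sketched here as an assumption rather than anything the PR implements: synthesize an id when wrapping structured content in a `ToolCallPart`.

```python
# Hypothetical helper: structured-output responses carry no tool_call_id, so
# generate one. The 'call_' prefix mirrors the usual shape of OpenAI ids; any
# unique string should do for round-tripping message history.
from uuid import uuid4

def synthetic_tool_call_id() -> str:
    return f'call_{uuid4().hex}'

# e.g. ToolCallPart.from_raw_args(name, choice.message.content, synthetic_tool_call_id())
```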
> **Review comment:** If there are multiple result tools (which I believe is the case when the result type is a union), we would definitely need to ensure that all the tool calls are present in the final `response_format`.

> **Review comment:** Also […] (Or, find a way for the response format to allow raw strings.)
>
> **Reply:** Regarding multiple result tools: I couldn't find a case where the number of result tools is more than 1, even in the case of a union. If you have any examples where there is more than one result tool, please tell me.
>
> Regarding the `allow_text_results` parameter: please check out @samuelcolvin's comment. The response format should allow raw strings if we set `result_type=str`.
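On the raw-strings question: OpenAI's `json_schema` response format requires the root schema to be an object, so a plain `result_type=str` cannot be expressed directly. Below is a hedged sketch of one conceivable workaround (a hypothetical wrapper schema, not something this PR implements): box the string in an object, then unwrap it after parsing.

```python
# Hypothetical wrapper: make a bare string result legal as a structured-outputs
# root schema by nesting it in an object, then unwrap after parsing.
import json

STRING_WRAPPER_SCHEMA = {
    'type': 'object',
    'properties': {'response': {'type': 'string'}},
    'required': ['response'],
    'additionalProperties': False,
}

def unwrap_string_result(content: str) -> str:
    # `content` is the JSON text the model returned under the wrapper schema
    return json.loads(content)['response']
```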