Use OpenAI's Structured Outputs feature to prevent validation errors #514

Open · wants to merge 8 commits into main
1 change: 1 addition & 0 deletions .gitignore
@@ -13,3 +13,4 @@ env*/
 /pydantic_ai_examples/.chat_app_messages.sqlite
 .cache/
 .vscode/
+*.ipynb
42 changes: 22 additions & 20 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -42,6 +42,7 @@
     from openai.types import ChatModel, chat
     from openai.types.chat import ChatCompletionChunk
     from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
+    from openai.types.shared_params.response_format_json_schema import ResponseFormatJSONSchema
 except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
@@ -111,14 +112,8 @@ async def agent_model(
     ) -> AgentModel:
         check_allow_model_requests()
         tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return OpenAIAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
-        )
+        response_format = self._map_response_format(result_tools[0]) if result_tools else None

@dmontagu (Contributor) commented on Dec 23, 2024:

If there are multiple result tools (which I believe is the case when the result type is a union), we would definitely need to ensure that all the tool calls are present in the final response_format
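
A hedged sketch of one way to cover that case, collapsing all result-tool schemas into a single anyOf schema before building the response_format (the helper name _map_union_response_format and the import paths are assumptions, not part of this PR):

    from openai.types.shared_params.response_format_json_schema import ResponseFormatJSONSchema

    from pydantic_ai.tools import ToolDefinition


    def _map_union_response_format(result_tools: list[ToolDefinition]) -> ResponseFormatJSONSchema:
        if len(result_tools) == 1:
            # Single result tool: use its schema directly, as the PR does.
            tool = result_tools[0]
            name, schema = tool.name, tool.parameters_json_schema
        else:
            # Union result type: the output must match one of the candidate schemas.
            name = 'final_result'
            schema = {'anyOf': [t.parameters_json_schema for t in result_tools]}
        return {'type': 'json_schema', 'json_schema': {'name': name, 'schema': schema}}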

@dmontagu (Contributor):

Also

Suggested change:
-        response_format = self._map_response_format(result_tools[0]) if result_tools else None
+        response_format = self._map_response_format(result_tools[0]) if result_tools and not allow_text_result else None

(Or, find a way for the response format to allow raw strings)

Author's reply:

Regarding multiple result tools: I couldn't find a case where the number of result tools is more than one, even in the case of a union. If you have any examples where there is more than one result tool, please share them.

Regarding the allow_text_result parameter: please check out @samuelcolvin's comment. The response format should allow raw strings if we set result_type=str.
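
For reference, OpenAI's json_schema response format expects an object at the schema root, so supporting result_type=str would presumably mean wrapping the string in an object along these lines (the text_response name and the response key are invented for illustration):

    string_response_format = {
        'type': 'json_schema',
        'json_schema': {
            'name': 'text_response',
            'schema': {
                # The root must be an object, so wrap the raw string in one field.
                'type': 'object',
                'properties': {'response': {'type': 'string'}},
                'required': ['response'],
            },
        },
    }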

+        return OpenAIAgentModel(self.client, self.model_name, tools, response_format)
 
     def name(self) -> str:
         return f'openai:{self.model_name}'
@@ -134,15 +129,22 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
             },
         }
 
+    @staticmethod
+    def _map_response_format(f: ToolDefinition) -> ResponseFormatJSONSchema:
+        return {
+            'type': 'json_schema',
+            'json_schema': {'name': f.name, 'description': f.description, 'schema': f.parameters_json_schema},
+        }
 
 
 @dataclass
 class OpenAIAgentModel(AgentModel):
     """Implementation of `AgentModel` for OpenAI models."""
 
     client: AsyncOpenAI
     model_name: OpenAIModelName
-    allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
+    response_format: ResponseFormatJSONSchema | None
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
@@ -174,20 +176,15 @@ async def _completions_create(
         self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
-        if not self.tools:
-            tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not self.allow_text_result:
-            tool_choice = 'required'
-        else:
-            tool_choice = 'auto'
-
+        tool_choice = 'auto' if self.tools else None

Member:

We have to respect allow_text_result; I think right now you're not.

Author's reply:

Yes, my changes ignore the allow_text_result parameter. The reason is that if we pass a result_type to the agent, text results are automatically excluded: the LLM's output generation is constrained to the given schema. In the case of PydanticAI's agents, the final response is then wrapped in a ToolCallPart. In other words, the information provided by allow_text_result is implicitly given by the length of the result_tools parameter. Please tell me if I missed anything.
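
A minimal standalone sketch of the behaviour described above (the model name and schema are illustrative only): with a json_schema response_format, the completion's content is constrained to the schema, so a free-text result cannot occur.

    from openai import AsyncOpenAI


    async def structured_reply() -> str:
        client = AsyncOpenAI()
        response = await client.chat.completions.create(
            model='gpt-4o',
            messages=[{'role': 'user', 'content': 'What is the capital of France?'}],
            response_format={
                'type': 'json_schema',
                'json_schema': {
                    'name': 'final_result',
                    'schema': {
                        'type': 'object',
                        'properties': {'answer': {'type': 'string'}},
                        'required': ['answer'],
                    },
                },
            },
        )
        # The content is a JSON document matching the schema,
        # e.g. '{"answer": "Paris"}', never free-form prose.
        return response.choices[0].message.content or ''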

         openai_messages = list(chain(*(self._map_message(m) for m in messages)))
 
         model_settings = model_settings or {}
 
-        return await self.client.chat.completions.create(
+        response = await self.client.chat.completions.create(

Member:

you don't need this change.

Author's reply:
Yes, we really don't need that.

             model=self.model_name,
             messages=openai_messages,
+            response_format=self.response_format or NOT_GIVEN,
             n=1,
             parallel_tool_calls=True if self.tools else NOT_GIVEN,
             tools=self.tools or NOT_GIVEN,

@@ -200,14 +197,19 @@ async def _completions_create(
             timeout=model_settings.get('timeout', NOT_GIVEN),
         )

-    @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
+        return response
+
+    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
         if choice.message.content is not None:
-            items.append(TextPart(choice.message.content))
+            if self.response_format:
+                name = self.response_format['json_schema']['name']
+                items.append(ToolCallPart.from_raw_args(name, choice.message.content))

Reviewer:

I tried this locally; OpenAI requires tool_call_id to be set, which needs to be passed through this method, but choice doesn't really come with an id. Either we generate an id, or perhaps use the choice index?

Author's reply:

I ran into the same case recently. Using the response_format parameter when calling OpenAI constrains the output generation to follow a given schema. The final response is not intended to be a tool call. As you can see, we use choice.message.content as the arguments, which doesn't provide a tool call id.

At the moment, PydanticAI wraps structured responses in ToolCallParts; in my opinion this needs to change, to better differentiate between actual tool calls and structured responses.

To solve this issue, we need to find a way for the agent to return the response directly instead of making a tool call from it.
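
As a stopgap for the missing id, one option (an assumption, not part of this PR) would be a synthetic tool_call_id derived from the choice index, keeping the ToolCallPart well formed. This fragment is meant to slot into _process_response above:

    if choice.message.content is not None and self.response_format:
        name = self.response_format['json_schema']['name']
        # `choice` carries no tool call id for response_format output, so
        # fabricate a stable one from the choice index (the prefix is made up).
        tool_call_id = f'response_format_{choice.index}'
        items.append(ToolCallPart.from_raw_args(name, choice.message.content, tool_call_id))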

+            else:
+                items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))