From 300349e0d355be8fb03585e1b7d660db157cd095 Mon Sep 17 00:00:00 2001 From: Renke Hohl Date: Fri, 20 Dec 2024 12:34:18 +0100 Subject: [PATCH 1/8] Use OpenAI's Structured Outputs feature to prevent validation errors --- .gitignore | 1 + pydantic_ai_slim/pydantic_ai/models/openai.py | 38 ++++++++++++------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index e18c7079..53fa7716 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ env*/ /pydantic_ai_examples/.chat_app_messages.sqlite .cache/ .vscode/ +*.ipynb diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index b84f5897..96053f7a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -10,6 +10,8 @@ from httpx import AsyncClient as AsyncHTTPClient from typing_extensions import assert_never +from openai.types import ResponseFormatJSONSchema + from .. import UnexpectedModelBehavior, _utils, result from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( @@ -111,13 +113,12 @@ async def agent_model( ) -> AgentModel: check_allow_model_requests() tools = [self._map_tool_definition(r) for r in function_tools] - if result_tools: - tools += [self._map_tool_definition(r) for r in result_tools] + response_format = self._map_response_format(result_tools[0]) if result_tools else None return OpenAIAgentModel( self.client, self.model_name, - allow_text_result, tools, + response_format ) def name(self) -> str: @@ -134,6 +135,17 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam: }, } + @staticmethod + def _map_response_format(f: ToolDefinition) -> dict: + return { + 'type': 'json_schema', + 'json_schema': { + 'name': f.name, + 'description': f.description, + 'schema': f.parameters_json_schema, + }, + } + @dataclass class OpenAIAgentModel(AgentModel): @@ -141,8 +153,8 @@ class OpenAIAgentModel(AgentModel): client: AsyncOpenAI model_name: OpenAIModelName - allow_text_result: bool tools: list[chat.ChatCompletionToolParam] + response_format: ResponseFormatJSONSchema | None async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None @@ -174,13 +186,7 @@ async def _completions_create( self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]: # standalone function to make it easier to override - if not self.tools: - tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not self.allow_text_result: - tool_choice = 'required' - else: - tool_choice = 'auto' - + tool_choice = 'auto' if self.tools else None openai_messages = list(chain(*(self._map_message(m) for m in messages))) model_settings = model_settings or {} @@ -188,6 +194,7 @@ async def _completions_create( return await self.client.chat.completions.create( model=self.model_name, messages=openai_messages, + response_format=self.response_format or NOT_GIVEN, n=1, parallel_tool_calls=True if self.tools else NOT_GIVEN, tools=self.tools or NOT_GIVEN, @@ -200,14 +207,17 @@ async def _completions_create( timeout=model_settings.get('timeout', NOT_GIVEN), ) - @staticmethod - def _process_response(response: chat.ChatCompletion) -> ModelResponse: + def _process_response(self, response: chat.ChatCompletion) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc) choice = response.choices[0] items: list[ModelResponsePart] = [] if choice.message.content is not None: - items.append(TextPart(choice.message.content)) + if self.response_format: + name = self.response_format['json_schema']['name'] + items.append(ToolCallPart.from_raw_args(name, choice.message.content)) + else: + items.append(TextPart(choice.message.content)) if choice.message.tool_calls is not None: for c in choice.message.tool_calls: items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id)) From 5f6eebca1165b320a71ad637782c8741d8fce21d Mon Sep 17 00:00:00 2001 From: Renke Hohl Date: Fri, 20 Dec 2024 12:44:04 +0100 Subject: [PATCH 2/8] Reformat openai.py --- pydantic_ai_slim/pydantic_ai/models/openai.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 96053f7a..661d6067 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -114,12 +114,7 @@ async def agent_model( check_allow_model_requests() tools = [self._map_tool_definition(r) for r in function_tools] response_format = self._map_response_format(result_tools[0]) if result_tools else None - return OpenAIAgentModel( - self.client, - self.model_name, - tools, - response_format - ) + return OpenAIAgentModel(self.client, self.model_name, tools, response_format) def name(self) -> str: return f'openai:{self.model_name}' From 3c32dcec8ec87c54b6f7eddd28da59c103c6e124 Mon Sep 17 00:00:00 2001 From: Renke Hohl Date: Fri, 20 Dec 2024 12:46:13 +0100 Subject: [PATCH 3/8] Reformat openai.py --- pydantic_ai_slim/pydantic_ai/models/openai.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 661d6067..979cc573 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -9,7 +9,6 @@ from httpx import AsyncClient as AsyncHTTPClient from typing_extensions import assert_never - from openai.types import ResponseFormatJSONSchema from .. import UnexpectedModelBehavior, _utils, result From 70cfeecd49a2d4f20eb8530ef9aee63d6099dd4b Mon Sep 17 00:00:00 2001 From: Renke Hohl Date: Fri, 20 Dec 2024 12:48:16 +0100 Subject: [PATCH 4/8] Reformat openai.py --- pydantic_ai_slim/pydantic_ai/models/openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 979cc573..df0c1ee5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -5,11 +5,11 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from itertools import chain +from openai.types import ResponseFormatJSONSchema from typing import Literal, Union, overload from httpx import AsyncClient as AsyncHTTPClient from typing_extensions import assert_never -from openai.types import ResponseFormatJSONSchema from .. import UnexpectedModelBehavior, _utils, result from .._utils import guard_tool_call_id as _guard_tool_call_id From c33836bae0451e6033d89cddf89fbed7cc7e2984 Mon Sep 17 00:00:00 2001 From: Renke Hohl Date: Fri, 20 Dec 2024 15:52:10 +0100 Subject: [PATCH 5/8] Reformat openai.py --- pydantic_ai_slim/pydantic_ai/models/openai.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index df0c1ee5..7c70b450 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -5,7 +5,6 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from itertools import chain -from openai.types import ResponseFormatJSONSchema from typing import Literal, Union, overload from httpx import AsyncClient as AsyncHTTPClient @@ -43,6 +42,7 @@ from openai.types import ChatModel, chat from openai.types.chat import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall + from openai.types.shared_params.response_format_json_schema import ResponseFormatJSONSchema except ImportError as _import_error: raise ImportError( 'Please install `openai` to use the OpenAI model, ' @@ -130,14 +130,14 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam: } @staticmethod - def _map_response_format(f: ToolDefinition) -> dict: + def _map_response_format(f: ToolDefinition) -> ResponseFormatJSONSchema: return { 'type': 'json_schema', 'json_schema': { 'name': f.name, 'description': f.description, - 'schema': f.parameters_json_schema, - }, + 'schema': f.parameters_json_schema + } } @@ -185,7 +185,7 @@ async def _completions_create( model_settings = model_settings or {} - return await self.client.chat.completions.create( + response = await self.client.chat.completions.create( model=self.model_name, messages=openai_messages, response_format=self.response_format or NOT_GIVEN, @@ -201,6 +201,8 @@ async def _completions_create( timeout=model_settings.get('timeout', NOT_GIVEN), ) + return response + def _process_response(self, response: chat.ChatCompletion) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc) From 0954e455535f834eb59c09b45bd99d31e0fe41ad Mon Sep 17 00:00:00 2001 From: Renke Hohl Date: Fri, 20 Dec 2024 16:10:19 +0100 Subject: [PATCH 6/8] Reformat openai.py --- pydantic_ai_slim/pydantic_ai/models/openai.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 7c70b450..8b18a38a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -133,11 +133,7 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam: def _map_response_format(f: ToolDefinition) -> ResponseFormatJSONSchema: return { 'type': 'json_schema', - 'json_schema': { - 'name': f.name, - 'description': f.description, - 'schema': f.parameters_json_schema - } + 'json_schema': {'name': f.name, 'description': f.description, 'schema': f.parameters_json_schema} } From 3f81f1a21eb98b84fcea282dc9cb709820d7bdc8 Mon Sep 17 00:00:00 2001 From: Renke Hohl Date: Fri, 20 Dec 2024 16:12:28 +0100 Subject: [PATCH 7/8] Reformat openai.py --- pydantic_ai_slim/pydantic_ai/models/openai.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 8b18a38a..5b67ce77 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -133,7 +133,11 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam: def _map_response_format(f: ToolDefinition) -> ResponseFormatJSONSchema: return { 'type': 'json_schema', - 'json_schema': {'name': f.name, 'description': f.description, 'schema': f.parameters_json_schema} + 'json_schema': { + 'name': f.name, + 'description': f.description, + 'schema': f.parameters_json_schema + }, } From 8057da16724fd598881757bee275a59ae026dfb8 Mon Sep 17 00:00:00 2001 From: Renke Hohl Date: Fri, 20 Dec 2024 16:13:39 +0100 Subject: [PATCH 8/8] Reformat openai.py --- pydantic_ai_slim/pydantic_ai/models/openai.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 5b67ce77..e255bddb 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -133,11 +133,7 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam: def _map_response_format(f: ToolDefinition) -> ResponseFormatJSONSchema: return { 'type': 'json_schema', - 'json_schema': { - 'name': f.name, - 'description': f.description, - 'schema': f.parameters_json_schema - }, + 'json_schema': {'name': f.name, 'description': f.description, 'schema': f.parameters_json_schema}, }