From cf9444eff8469e6251e18677f0e40ab1741ff30a Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Sat, 21 Dec 2024 15:12:25 -0700 Subject: [PATCH] Use role 'developer' for OpenAIModel system prompts --- pydantic_ai_slim/pydantic_ai/models/openai.py | 49 ++++++++++--------- pydantic_ai_slim/pyproject.toml | 2 +- pyproject.toml | 5 +- uv.lock | 10 ++-- 4 files changed, 33 insertions(+), 33 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index b84f5897..b3db8c69 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -15,6 +15,7 @@ from ..messages import ( ModelMessage, ModelRequest, + ModelRequestPart, ModelResponse, ModelResponsePart, RetryPromptPart, @@ -113,12 +114,7 @@ async def agent_model( tools = [self._map_tool_definition(r) for r in function_tools] if result_tools: tools += [self._map_tool_definition(r) for r in result_tools] - return OpenAIAgentModel( - self.client, - self.model_name, - allow_text_result, - tools, - ) + return OpenAIAgentModel(self.client, self.model_name, allow_text_result, tools) def name(self) -> str: return f'openai:{self.model_name}' @@ -271,27 +267,32 @@ def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMess @classmethod def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]: for part in message.parts: - if isinstance(part, SystemPromptPart): - yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content) - elif isinstance(part, UserPromptPart): - yield chat.ChatCompletionUserMessageParam(role='user', content=part.content) - elif isinstance(part, ToolReturnPart): - yield chat.ChatCompletionToolMessageParam( + if part := cls._map_model_request_part(part): + yield part + + @classmethod + def _map_model_request_part(cls, part: ModelRequestPart) -> chat.ChatCompletionMessageParam | None: + if isinstance(part, SystemPromptPart): + return chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content) + elif isinstance(part, UserPromptPart): + return chat.ChatCompletionUserMessageParam(role='user', content=part.content) + elif isinstance(part, ToolReturnPart): + return chat.ChatCompletionToolMessageParam( + role='tool', + tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'), + content=part.model_response_str(), + ) + elif isinstance(part, RetryPromptPart): + if part.tool_name is None: + return chat.ChatCompletionUserMessageParam(role='user', content=part.model_response()) + else: + return chat.ChatCompletionToolMessageParam( role='tool', tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'), - content=part.model_response_str(), + content=part.model_response(), ) - elif isinstance(part, RetryPromptPart): - if part.tool_name is None: - yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response()) - else: - yield chat.ChatCompletionToolMessageParam( - role='tool', - tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'), - content=part.model_response(), - ) - else: - assert_never(part) + else: + assert_never(part) @dataclass diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index dd12b95d..87194d81 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -42,7 +42,7 @@ dependencies = [ ] [project.optional-dependencies] -openai = ["openai>=1.54.3"] +openai = ["openai>=1.59.0"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"] anthropic = ["anthropic>=0.40.0"] groq = ["groq>=0.12.0"] diff --git a/pyproject.toml b/pyproject.toml index b1a3ffb3..91bda063 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,9 @@ classifiers = [ ] requires-python = ">=3.9" -dependencies = ["pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral]==0.0.14"] +dependencies = [ + "pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral]==0.0.14", +] [project.urls] Homepage = "https://ai.pydantic.dev" @@ -52,6 +54,7 @@ logfire = ["logfire>=2.3"] [tool.uv.sources] pydantic-ai-slim = { workspace = true } pydantic-ai-examples = { workspace = true } +openai = { git = "https://github.com/openai/openai-python" } [tool.uv.workspace] members = ["pydantic_ai_slim", "pydantic_ai_examples"] diff --git a/uv.lock b/uv.lock index cb7a7f10..fa11b423 100644 --- a/uv.lock +++ b/uv.lock @@ -1198,8 +1198,8 @@ wheels = [ [[package]] name = "openai" -version = "1.55.0" -source = { registry = "https://pypi.org/simple" } +version = "1.59.0" +source = { git = "https://github.com/openai/openai-python#89d49335a02ac231925e5a514659c93322f29526" } dependencies = [ { name = "anyio" }, { name = "distro" }, @@ -1210,10 +1210,6 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f8/7b/b1a3b6fa17dc523c603121dd23615bcd895a9fc3ab23be92307b9347bc50/openai-1.55.0.tar.gz", hash = "sha256:6c0975ac8540fe639d12b4ff5a8e0bf1424c844c4a4251148f59f06c4b2bd5db", size = 313963 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/81/b2/9d8939c9ef73e6a2d5cb3366ef3fabd728a8de4729210d9af785c0edc6ec/openai-1.55.0-py3-none-any.whl", hash = "sha256:446e08918f8dd70d8723274be860404c8c7cc46b91b93bbc0ef051f57eb503c1", size = 389528 }, -] [[package]] name = "opentelemetry-api" @@ -1713,7 +1709,7 @@ requires-dist = [ { name = "logfire", marker = "extra == 'logfire'", specifier = ">=2.3" }, { name = "logfire-api", specifier = ">=1.2.0" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, - { name = "openai", marker = "extra == 'openai'", specifier = ">=1.54.3" }, + { name = "openai", marker = "extra == 'openai'", git = "https://github.com/openai/openai-python" }, { name = "pydantic", specifier = ">=2.10" }, { name = "requests", marker = "extra == 'vertexai'", specifier = ">=2.32.3" }, ]