Skip to content

Commit

Permalink
Use role 'developer' for OpenAIModel system prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Dec 21, 2024
1 parent a193111 commit cf9444e
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 33 deletions.
49 changes: 25 additions & 24 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..messages import (
ModelMessage,
ModelRequest,
ModelRequestPart,
ModelResponse,
ModelResponsePart,
RetryPromptPart,
Expand Down Expand Up @@ -113,12 +114,7 @@ async def agent_model(
tools = [self._map_tool_definition(r) for r in function_tools]
if result_tools:
tools += [self._map_tool_definition(r) for r in result_tools]
return OpenAIAgentModel(
self.client,
self.model_name,
allow_text_result,
tools,
)
return OpenAIAgentModel(self.client, self.model_name, allow_text_result, tools)

def name(self) -> str:
return f'openai:{self.model_name}'
Expand Down Expand Up @@ -271,27 +267,32 @@ def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMess
@classmethod
def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
for part in message.parts:
if isinstance(part, SystemPromptPart):
yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
elif isinstance(part, UserPromptPart):
yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
elif isinstance(part, ToolReturnPart):
yield chat.ChatCompletionToolMessageParam(
if part := cls._map_model_request_part(part):
yield part

@classmethod
def _map_model_request_part(cls, part: ModelRequestPart) -> chat.ChatCompletionMessageParam | None:
if isinstance(part, SystemPromptPart):
return chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
elif isinstance(part, UserPromptPart):
return chat.ChatCompletionUserMessageParam(role='user', content=part.content)
elif isinstance(part, ToolReturnPart):
return chat.ChatCompletionToolMessageParam(
role='tool',
tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
content=part.model_response_str(),
)
elif isinstance(part, RetryPromptPart):
if part.tool_name is None:
return chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
else:
return chat.ChatCompletionToolMessageParam(
role='tool',
tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
content=part.model_response_str(),
content=part.model_response(),
)
elif isinstance(part, RetryPromptPart):
if part.tool_name is None:
yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
else:
yield chat.ChatCompletionToolMessageParam(
role='tool',
tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
content=part.model_response(),
)
else:
assert_never(part)
else:
assert_never(part)


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ dependencies = [
]

[project.optional-dependencies]
openai = ["openai>=1.54.3"]
openai = ["openai>=1.59.0"]
vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
anthropic = ["anthropic>=0.40.0"]
groq = ["groq>=0.12.0"]
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ classifiers = [
]
requires-python = ">=3.9"

dependencies = ["pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral]==0.0.14"]
dependencies = [
"pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral]==0.0.14",
]

[project.urls]
Homepage = "https://ai.pydantic.dev"
Expand All @@ -52,6 +54,7 @@ logfire = ["logfire>=2.3"]
[tool.uv.sources]
pydantic-ai-slim = { workspace = true }
pydantic-ai-examples = { workspace = true }
openai = { git = "https://github.com/openai/openai-python" }

[tool.uv.workspace]
members = ["pydantic_ai_slim", "pydantic_ai_examples"]
Expand Down
10 changes: 3 additions & 7 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit cf9444e

Please sign in to comment.