From a106229568d7c5b16de1c763ef6921b2a0406be5 Mon Sep 17 00:00:00 2001 From: Sunil Sattiraju Date: Mon, 14 Oct 2024 05:28:53 +0800 Subject: [PATCH] Support structured output (#3732) * Support structured output * use ruff format * add type checking for cookbook * add the notebook to index.md * fix the type error * pass response_format explicitly * remove casting * ensure type are correct * seperate response_format arg * fix type and resolve pyright errors --------- Co-authored-by: Eric Zhu --- .../core-user-guide/cookbook/index.md | 3 +- .../cookbook/structured-output-agent.ipynb | 161 ++++++++++++++++++ .../components/models/_openai_client.py | 104 +++++++++-- 3 files changed, 250 insertions(+), 18 deletions(-) create mode 100644 python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/structured-output-agent.ipynb diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/index.md b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/index.md index 1141b7262ddb..711b67f04845 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/index.md +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/index.md @@ -13,4 +13,5 @@ local-llms-ollama-litellm instrumenting topic-subscription-scenarios azure-container-code-executor -``` \ No newline at end of file +structured-output-agent +``` diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/structured-output-agent.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/structured-output-agent.ipynb new file mode 100644 index 000000000000..5c0a1c5c59d5 --- /dev/null +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/structured-output-agent.ipynb @@ -0,0 +1,161 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Strcutured output using GPT-4o models\n", + "\n", + "This cookbook demonstrates how to obtain structured output using GPT-4o models. The OpenAI beta client SDK provides a parse helper that allows you to use your own Pydantic model, eliminating the need to define a JSON schema. This approach is recommended for supported models.\n", + "\n", + "Currently, this feature is supported for:\n", + "\n", + "- gpt-4o-mini on OpenAI\n", + "- gpt-4o-2024-08-06 on OpenAI\n", + "- gpt-4o-2024-08-06 on Azure" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define a simple message type that carries explanation and output for a Math problem" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from pydantic import BaseModel\n", + "\n", + "\n", + "class MathReasoning(BaseModel):\n", + " class Step(BaseModel):\n", + " explanation: str\n", + " output: str\n", + "\n", + " steps: list[Step]\n", + " final_answer: str" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# Set the environment variable\n", + "os.environ[\"AZURE_OPENAI_ENDPOINT\"] = \"https://YOUR_ENDPOINT_DETAILS.openai.azure.com/\"\n", + "os.environ[\"AZURE_OPENAI_API_KEY\"] = \"YOUR_API_KEY\"\n", + "os.environ[\"AZURE_OPENAI_DEPLOYMENT_NAME\"] = \"gpt-4o-2024-08-06\"\n", + "os.environ[\"AZURE_OPENAI_API_VERSION\"] = \"2024-08-01-preview\"" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import os\n", + "from typing import Optional\n", + "\n", + "from autogen_core.components.models import AzureOpenAIChatCompletionClient, UserMessage\n", + "\n", + "\n", + "# Function to get environment variable and ensure it is not None\n", + "def get_env_variable(name: str) -> str:\n", + " value = os.getenv(name)\n", + " if value is None:\n", + " raise ValueError(f\"Environment variable {name} is not set\")\n", + " return value\n", + "\n", + "\n", + "# Create the client with type-checked environment variables\n", + "client = AzureOpenAIChatCompletionClient(\n", + " model=get_env_variable(\"AZURE_OPENAI_DEPLOYMENT_NAME\"),\n", + " api_version=get_env_variable(\"AZURE_OPENAI_API_VERSION\"),\n", + " azure_endpoint=get_env_variable(\"AZURE_OPENAI_ENDPOINT\"),\n", + " api_key=get_env_variable(\"AZURE_OPENAI_API_KEY\"),\n", + " model_capabilities={\n", + " \"vision\": False,\n", + " \"function_calling\": True,\n", + " \"json_output\": True,\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'steps': [{'explanation': 'Start by aligning the numbers vertically.', 'output': '\\n 16\\n+ 32'}, {'explanation': 'Add the units digits: 6 + 2 = 8.', 'output': '\\n 16\\n+ 32\\n 8'}, {'explanation': 'Add the tens digits: 1 + 3 = 4.', 'output': '\\n 16\\n+ 32\\n 48'}], 'final_answer': '48'}\n" + ] + }, + { + "data": { + "text/plain": [ + "MathReasoning(steps=[Step(explanation='Start by aligning the numbers vertically.', output='\\n 16\\n+ 32'), Step(explanation='Add the units digits: 6 + 2 = 8.', output='\\n 16\\n+ 32\\n 8'), Step(explanation='Add the tens digits: 1 + 3 = 4.', output='\\n 16\\n+ 32\\n 48')], final_answer='48')" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Define the user message\n", + "messages = [\n", + " UserMessage(content=\"What is 16 + 32?\", source=\"user\"),\n", + "]\n", + "\n", + "# Call the create method on the client, passing the messages and additional arguments\n", + "# The extra_create_args dictionary includes the response format as MathReasoning model we defined above\n", + "# Providing the response format and pydantic model will use the new parse method from beta SDK\n", + "response = await client.create(messages=messages, extra_create_args={\"response_format\": MathReasoning})\n", + "\n", + "# Ensure the response content is a valid JSON string before loading it\n", + "response_content: Optional[str] = response.content if isinstance(response.content, str) else None\n", + "if response_content is None:\n", + " raise ValueError(\"Response content is not a valid JSON string\")\n", + "\n", + "# Print the response content after loading it as JSON\n", + "print(json.loads(response_content))\n", + "\n", + "# Validate the response content with the MathReasoning model\n", + "MathReasoning.model_validate(json.loads(response_content))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/packages/autogen-core/src/autogen_core/components/models/_openai_client.py b/python/packages/autogen-core/src/autogen_core/components/models/_openai_client.py index b2232756b163..2992c8a600c1 100644 --- a/python/packages/autogen-core/src/autogen_core/components/models/_openai_client.py +++ b/python/packages/autogen-core/src/autogen_core/components/models/_openai_client.py @@ -5,6 +5,7 @@ import math import re import warnings +from asyncio import Task from typing import ( Any, AsyncGenerator, @@ -14,6 +15,7 @@ Optional, Sequence, Set, + Type, Union, cast, ) @@ -21,6 +23,7 @@ import tiktoken from openai import AsyncAzureOpenAI, AsyncOpenAI from openai.types.chat import ( + ChatCompletion, ChatCompletionAssistantMessageParam, ChatCompletionContentPartParam, ChatCompletionContentPartTextParam, @@ -31,9 +34,13 @@ ChatCompletionToolMessageParam, ChatCompletionToolParam, ChatCompletionUserMessageParam, + ParsedChatCompletion, + ParsedChoice, completion_create_params, ) +from openai.types.chat.chat_completion import Choice from openai.types.shared_params import FunctionDefinition, FunctionParameters +from pydantic import BaseModel from typing_extensions import Unpack from ...application.logging import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME @@ -279,10 +286,10 @@ def convert_tools( type="function", function=FunctionDefinition( name=tool_schema["name"], - description=tool_schema["description"] if "description" in tool_schema else "", - parameters=cast(FunctionParameters, tool_schema["parameters"]) - if "parameters" in tool_schema - else {}, + description=(tool_schema["description"] if "description" in tool_schema else ""), + parameters=( + cast(FunctionParameters, tool_schema["parameters"]) if "parameters" in tool_schema else {} + ), ), ) ) @@ -365,6 +372,24 @@ async def create( create_args = self._create_args.copy() create_args.update(extra_create_args) + # Declare use_beta_client + use_beta_client: bool = False + response_format_value: Optional[Type[BaseModel]] = None + + if "response_format" in create_args: + value = create_args["response_format"] + # If value is a Pydantic model class, use the beta client + if isinstance(value, type) and issubclass(value, BaseModel): + response_format_value = value + use_beta_client = True + else: + # response_format_value is not a Pydantic model class + use_beta_client = False + response_format_value = None + + # Remove 'response_format' from create_args to prevent passing it twice + create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"} + # TODO: allow custom handling. # For now we raise an error if images are present and vision is not supported if self.capabilities["vision"] is False: @@ -390,24 +415,69 @@ async def create( if self.capabilities["function_calling"] is False and len(tools) > 0: raise ValueError("Model does not support function calling") - + future: Union[Task[ParsedChatCompletion[BaseModel]], Task[ChatCompletion]] if len(tools) > 0: converted_tools = convert_tools(tools) - future = asyncio.ensure_future( - self._client.chat.completions.create( - messages=oai_messages, - stream=False, - tools=converted_tools, - **create_args, + if use_beta_client: + # Pass response_format_value if it's not None + if response_format_value is not None: + future = asyncio.ensure_future( + self._client.beta.chat.completions.parse( + messages=oai_messages, + tools=converted_tools, + response_format=response_format_value, + **create_args_no_response_format, + ) + ) + else: + future = asyncio.ensure_future( + self._client.beta.chat.completions.parse( + messages=oai_messages, + tools=converted_tools, + **create_args_no_response_format, + ) + ) + else: + future = asyncio.ensure_future( + self._client.chat.completions.create( + messages=oai_messages, + stream=False, + tools=converted_tools, + **create_args, + ) ) - ) else: - future = asyncio.ensure_future( - self._client.chat.completions.create(messages=oai_messages, stream=False, **create_args) - ) + if use_beta_client: + if response_format_value is not None: + future = asyncio.ensure_future( + self._client.beta.chat.completions.parse( + messages=oai_messages, + response_format=response_format_value, + **create_args_no_response_format, + ) + ) + else: + future = asyncio.ensure_future( + self._client.beta.chat.completions.parse( + messages=oai_messages, + **create_args_no_response_format, + ) + ) + else: + future = asyncio.ensure_future( + self._client.chat.completions.create( + messages=oai_messages, + stream=False, + **create_args, + ) + ) + if cancellation_token is not None: cancellation_token.link_future(future) - result = await future + result: Union[ParsedChatCompletion[BaseModel], ChatCompletion] = await future + if use_beta_client: + result = cast(ParsedChatCompletion[Any], result) + if result.usage is not None: logger.info( LLMCallEvent( @@ -430,7 +500,7 @@ async def create( ) # Limited to a single choice currently. - choice = result.choices[0] + choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], Choice] = result.choices[0] if choice.finish_reason == "function_call": raise ValueError("Function calls are not supported in this context")