From 9cec541630af22c2e2683d1f44a6ddd1d2e2425f Mon Sep 17 00:00:00 2001 From: Yiran Wu <32823396+kevin666aa@users.noreply.github.com> Date: Sat, 9 Dec 2023 22:28:13 -0500 Subject: [PATCH] Convert ChatCompletionMessage to Dict after completion (#791) * update * update * update signature * update * update * fix test funccall groupchat * reverse change * update * update * update * update * update --------- Co-authored-by: Qingyun Wu Co-authored-by: Chi Wang --- .github/workflows/build.yml | 2 +- .../agentchat/contrib/compressible_agent.py | 2 +- autogen/agentchat/conversable_agent.py | 7 ++- autogen/oai/client.py | 30 ++++++++--- test/agentchat/test_function_call.py | 2 +- test/oai/test_client.py | 50 +++++++++++++++++-- test/oai/test_client_stream.py | 8 +-- website/docs/Installation.md | 2 +- website/docs/Use-Cases/enhanced_inference.md | 6 +-- 9 files changed, 87 insertions(+), 22 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0726e410bd0..8d4b84a301a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -49,7 +49,7 @@ jobs: - name: Coverage if: matrix.python-version == '3.10' run: | - pip install -e .[mathchat,test] + pip install -e .[test] pip uninstall -y openai coverage run -a -m pytest test --ignore=test/agentchat/contrib coverage xml diff --git a/autogen/agentchat/contrib/compressible_agent.py b/autogen/agentchat/contrib/compressible_agent.py index f1de41512e9..dc8c80a01ff 100644 --- a/autogen/agentchat/contrib/compressible_agent.py +++ b/autogen/agentchat/contrib/compressible_agent.py @@ -403,7 +403,7 @@ def compress_messages( print(colored(f"Failed to compress the content due to {e}", "red"), flush=True) return False, None - compressed_message = self.client.extract_text_or_function_call(response)[0] + compressed_message = self.client.extract_text_or_completion_object(response)[0] assert isinstance(compressed_message, str), f"compressed_message should be a string: {compressed_message}" if self.compress_config["verbose"]: print( diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 590d11afcdd..493a83da8a5 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -631,7 +631,12 @@ def generate_oai_reply( response = client.create( context=messages[-1].pop("context", None), messages=self._oai_system_message + messages ) - return True, client.extract_text_or_function_call(response)[0] + + # TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged. + extracted_response = client.extract_text_or_completion_object(response)[0] + if not isinstance(extracted_response, str): + extracted_response = extracted_response.model_dump(mode="dict") + return True, extracted_response async def a_generate_oai_reply( self, diff --git a/autogen/oai/client.py b/autogen/oai/client.py index fedcea4d65b..a4714075b0f 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -10,7 +10,9 @@ from autogen.oai.openai_utils import get_key, oai_price1k from autogen.token_count_utils import count_token +TOOL_ENABLED = False try: + import openai from openai import OpenAI, APIError from openai.types.chat import ChatCompletion from openai.types.chat.chat_completion import ChatCompletionMessage, Choice @@ -18,6 +20,8 @@ from openai.types.completion_usage import CompletionUsage import diskcache + if openai.__version__ >= "1.1.0": + TOOL_ENABLED = True ERROR = None except ImportError: ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.") @@ -205,7 +209,7 @@ def create(self, **config): ```python def yes_or_no_filter(context, response): return context.get("yes_or_no_choice", False) is False or any( - text in ["Yes.", "No."] for text in client.extract_text_or_function_call(response) + text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response) ) ``` @@ -442,21 +446,33 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float: return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 @classmethod - def extract_text_or_function_call(cls, response: ChatCompletion | Completion) -> List[str]: - """Extract the text or function calls from a completion or chat response. + def extract_text_or_completion_object( + cls, response: ChatCompletion | Completion + ) -> Union[List[str], List[ChatCompletionMessage]]: + """Extract the text or ChatCompletion objects from a completion or chat response. Args: response (ChatCompletion | Completion): The response from openai. Returns: - A list of text or function calls in the responses. + A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present. """ choices = response.choices if isinstance(response, Completion): return [choice.text for choice in choices] - return [ - choice.message if choice.message.function_call is not None else choice.message.content for choice in choices - ] + + if TOOL_ENABLED: + return [ + choice.message + if choice.message.function_call is not None or choice.message.tool_calls is not None + else choice.message.content + for choice in choices + ] + else: + return [ + choice.message if choice.message.function_call is not None else choice.message.content + for choice in choices + ] # TODO: logging diff --git a/test/agentchat/test_function_call.py b/test/agentchat/test_function_call.py index 3c2db3b5e48..cfe606b08eb 100644 --- a/test/agentchat/test_function_call.py +++ b/test/agentchat/test_function_call.py @@ -48,7 +48,7 @@ def test_eval_math_responses(): functions=functions, ) print(response) - responses = client.extract_text_or_function_call(response) + responses = client.extract_text_or_completion_object(response) print(responses[0]) function_call = responses[0].function_call name, arguments = function_call.name, json.loads(function_call.arguments) diff --git a/test/oai/test_client.py b/test/oai/test_client.py index 45033123846..aec241697ec 100644 --- a/test/oai/test_client.py +++ b/test/oai/test_client.py @@ -2,12 +2,18 @@ from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai from test_utils import OAI_CONFIG_LIST, KEY_LOC +TOOL_ENABLED = False try: from openai import OpenAI + from openai.types.chat.chat_completion import ChatCompletionMessage except ImportError: skip = True else: skip = False + import openai + + if openai.__version__ >= "1.1.0": + TOOL_ENABLED = True @pytest.mark.skipif(skip, reason="openai>=1 not installed") @@ -24,7 +30,44 @@ def test_aoai_chat_completion(): # response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) print(response) - print(client.extract_text_or_function_call(response)) + print(client.extract_text_or_completion_object(response)) + + +@pytest.mark.skipif(skip and not TOOL_ENABLED, reason="openai>=1.1.0 not installed") +def test_oai_tool_calling_extraction(): + config_list = config_list_from_json( + env_or_file=OAI_CONFIG_LIST, + file_location=KEY_LOC, + filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo"]}, + ) + client = OpenAIWrapper(config_list=config_list) + response = client.create( + messages=[ + { + "role": "user", + "content": "What is the weather in San Francisco?", + }, + ], + tools=[ + { + "type": "function", + "function": { + "name": "getCurrentWeather", + "description": "Get the weather in location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + }, + } + ], + ) + print(response) + print(client.extract_text_or_completion_object(response)) @pytest.mark.skipif(skip, reason="openai>=1 not installed") @@ -36,7 +79,7 @@ def test_chat_completion(): client = OpenAIWrapper(config_list=config_list) response = client.create(messages=[{"role": "user", "content": "1+1="}]) print(response) - print(client.extract_text_or_function_call(response)) + print(client.extract_text_or_completion_object(response)) @pytest.mark.skipif(skip, reason="openai>=1 not installed") @@ -45,7 +88,7 @@ def test_completion(): client = OpenAIWrapper(config_list=config_list) response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct") print(response) - print(client.extract_text_or_function_call(response)) + print(client.extract_text_or_completion_object(response)) @pytest.mark.skipif(skip, reason="openai>=1 not installed") @@ -96,6 +139,7 @@ def test_usage_summary(): if __name__ == "__main__": test_aoai_chat_completion() + test_oai_tool_calling_extraction() test_chat_completion() test_completion() test_cost() diff --git a/test/oai/test_client_stream.py b/test/oai/test_client_stream.py index 5bc3ee3bc58..2583c4cac2b 100644 --- a/test/oai/test_client_stream.py +++ b/test/oai/test_client_stream.py @@ -20,7 +20,7 @@ def test_aoai_chat_completion_stream(): client = OpenAIWrapper(config_list=config_list) response = client.create(messages=[{"role": "user", "content": "2+2="}], stream=True) print(response) - print(client.extract_text_or_function_call(response)) + print(client.extract_text_or_completion_object(response)) @pytest.mark.skipif(skip, reason="openai>=1 not installed") @@ -33,7 +33,7 @@ def test_chat_completion_stream(): client = OpenAIWrapper(config_list=config_list) response = client.create(messages=[{"role": "user", "content": "1+1="}], stream=True) print(response) - print(client.extract_text_or_function_call(response)) + print(client.extract_text_or_completion_object(response)) @pytest.mark.skipif(skip, reason="openai>=1 not installed") @@ -66,7 +66,7 @@ def test_chat_functions_stream(): stream=True, ) print(response) - print(client.extract_text_or_function_call(response)) + print(client.extract_text_or_completion_object(response)) @pytest.mark.skipif(skip, reason="openai>=1 not installed") @@ -75,7 +75,7 @@ def test_completion_stream(): client = OpenAIWrapper(config_list=config_list) response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True) print(response) - print(client.extract_text_or_function_call(response)) + print(client.extract_text_or_completion_object(response)) if __name__ == "__main__": diff --git a/website/docs/Installation.md b/website/docs/Installation.md index 4f32824e5a5..e78c2807cdf 100644 --- a/website/docs/Installation.md +++ b/website/docs/Installation.md @@ -61,7 +61,7 @@ Therefore, some changes are required for users of `pyautogen<0.2`. from autogen import OpenAIWrapper client = OpenAIWrapper(config_list=config_list) response = client.create(messages=[{"role": "user", "content": "2+2="}]) -print(client.extract_text_or_function_call(response)) +print(client.extract_text_or_completion_object(response)) ``` - Inference parameter tuning and inference logging features are currently unavailable in `OpenAIWrapper`. Logging will be added in a future release. Inference parameter tuning can be done via [`flaml.tune`](https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function). diff --git a/website/docs/Use-Cases/enhanced_inference.md b/website/docs/Use-Cases/enhanced_inference.md index f96599b13dd..1313713bc26 100644 --- a/website/docs/Use-Cases/enhanced_inference.md +++ b/website/docs/Use-Cases/enhanced_inference.md @@ -119,7 +119,7 @@ client = OpenAIWrapper() # ChatCompletion response = client.create(messages=[{"role": "user", "content": "2+2="}], model="gpt-3.5-turbo") # extract the response text -print(client.extract_text_or_function_call(response)) +print(client.extract_text_or_completion_object(response)) # get cost of this completion print(response.cost) # Azure OpenAI endpoint @@ -127,7 +127,7 @@ client = OpenAIWrapper(api_key=..., base_url=..., api_version=..., api_type="azu # Completion response = client.create(prompt="2+2=", model="gpt-3.5-turbo-instruct") # extract the response text -print(client.extract_text_or_function_call(response)) +print(client.extract_text_or_completion_object(response)) ``` @@ -240,7 +240,7 @@ Another type of error is that the returned response does not satisfy a requireme ```python def valid_json_filter(response, **_): - for text in OpenAIWrapper.extract_text_or_function_call(response): + for text in OpenAIWrapper.extract_text_or_completion_object(response): try: json.loads(text) return True