Skip to content

Commit

Permalink
Convert ChatCompletionMessage to Dict after completion (#791)
Browse files Browse the repository at this point in the history
* update

* update

* update signature

* update

* update

* fix test funccall groupchat

* reverse change

* update

* update

* update

* update

* update

---------

Co-authored-by: Qingyun Wu <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
3 people authored Dec 10, 2023
1 parent a31b240 commit 9cec541
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
- name: Coverage
if: matrix.python-version == '3.10'
run: |
pip install -e .[mathchat,test]
pip install -e .[test]
pip uninstall -y openai
coverage run -a -m pytest test --ignore=test/agentchat/contrib
coverage xml
Expand Down
2 changes: 1 addition & 1 deletion autogen/agentchat/contrib/compressible_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def compress_messages(
print(colored(f"Failed to compress the content due to {e}", "red"), flush=True)
return False, None

compressed_message = self.client.extract_text_or_function_call(response)[0]
compressed_message = self.client.extract_text_or_completion_object(response)[0]
assert isinstance(compressed_message, str), f"compressed_message should be a string: {compressed_message}"
if self.compress_config["verbose"]:
print(
Expand Down
7 changes: 6 additions & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,12 @@ def generate_oai_reply(
response = client.create(
context=messages[-1].pop("context", None), messages=self._oai_system_message + messages
)
return True, client.extract_text_or_function_call(response)[0]

# TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
extracted_response = client.extract_text_or_completion_object(response)[0]
if not isinstance(extracted_response, str):
extracted_response = extracted_response.model_dump(mode="dict")
return True, extracted_response

async def a_generate_oai_reply(
self,
Expand Down
30 changes: 23 additions & 7 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@
from autogen.oai.openai_utils import get_key, oai_price1k
from autogen.token_count_utils import count_token

TOOL_ENABLED = False
try:
import openai
from openai import OpenAI, APIError
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion import Completion
from openai.types.completion_usage import CompletionUsage
import diskcache

if openai.__version__ >= "1.1.0":
TOOL_ENABLED = True
ERROR = None
except ImportError:
ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
Expand Down Expand Up @@ -205,7 +209,7 @@ def create(self, **config):
```python
def yes_or_no_filter(context, response):
return context.get("yes_or_no_choice", False) is False or any(
text in ["Yes.", "No."] for text in client.extract_text_or_function_call(response)
text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response)
)
```
Expand Down Expand Up @@ -442,21 +446,33 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000

@classmethod
def extract_text_or_completion_object(
    cls, response: ChatCompletion | Completion
) -> Union[List[str], List[ChatCompletionMessage]]:
    """Extract the text or ChatCompletionMessage objects from a completion or chat response.

    Args:
        response (ChatCompletion | Completion): The response from openai.

    Returns:
        A list of text, or a list of ChatCompletionMessage objects if
        function_call/tool_calls are present in the message.
    """
    # NOTE(review): the diff scrape had interleaved the pre-commit and post-commit
    # bodies of this method; this is the reconstructed post-commit version.
    choices = response.choices
    if isinstance(response, Completion):
        # Legacy text-completion API: each choice carries plain text.
        return [choice.text for choice in choices]

    if TOOL_ENABLED:
        # openai>=1.1.0: a message may carry tool_calls in addition to function_call;
        # return the full message object whenever either is present so callers can
        # inspect the calls, otherwise just the text content.
        return [
            choice.message
            if choice.message.function_call is not None or choice.message.tool_calls is not None
            else choice.message.content
            for choice in choices
        ]
    else:
        # Older openai 1.x without tool support: only function_call can be present.
        return [
            choice.message if choice.message.function_call is not None else choice.message.content
            for choice in choices
        ]


# TODO: logging
2 changes: 1 addition & 1 deletion test/agentchat/test_function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_eval_math_responses():
functions=functions,
)
print(response)
responses = client.extract_text_or_function_call(response)
responses = client.extract_text_or_completion_object(response)
print(responses[0])
function_call = responses[0].function_call
name, arguments = function_call.name, json.loads(function_call.arguments)
Expand Down
50 changes: 47 additions & 3 deletions test/oai/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
from test_utils import OAI_CONFIG_LIST, KEY_LOC

TOOL_ENABLED = False
try:
from openai import OpenAI
from openai.types.chat.chat_completion import ChatCompletionMessage
except ImportError:
skip = True
else:
skip = False
import openai

if openai.__version__ >= "1.1.0":
TOOL_ENABLED = True


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
Expand All @@ -24,7 +30,44 @@ def test_aoai_chat_completion():
# response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
print(response)
print(client.extract_text_or_function_call(response))
print(client.extract_text_or_completion_object(response))


# Bug fix: the original guard was `skip and not TOOL_ENABLED`, which evaluates to
# False when openai IS installed but at a version < 1.1.0 (skip=False,
# TOOL_ENABLED=False), so the test would run against an openai client that does
# not support the `tools` parameter. The test must be skipped when EITHER openai
# is missing OR tool calling is unavailable.
@pytest.mark.skipif(skip or not TOOL_ENABLED, reason="openai>=1.1.0 not installed")
def test_oai_tool_calling_extraction():
    """Smoke-test that a tool-calling chat completion round-trips through
    OpenAIWrapper and extract_text_or_completion_object without error."""
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo"]},
    )
    client = OpenAIWrapper(config_list=config_list)
    response = client.create(
        messages=[
            {
                "role": "user",
                "content": "What is the weather in San Francisco?",
            },
        ],
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "getCurrentWeather",
                    "description": "Get the weather in location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                            "unit": {"type": "string", "enum": ["c", "f"]},
                        },
                        "required": ["location"],
                    },
                },
            }
        ],
    )
    print(response)
    print(client.extract_text_or_completion_object(response))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
Expand All @@ -36,7 +79,7 @@ def test_chat_completion():
client = OpenAIWrapper(config_list=config_list)
response = client.create(messages=[{"role": "user", "content": "1+1="}])
print(response)
print(client.extract_text_or_function_call(response))
print(client.extract_text_or_completion_object(response))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
Expand All @@ -45,7 +88,7 @@ def test_completion():
client = OpenAIWrapper(config_list=config_list)
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct")
print(response)
print(client.extract_text_or_function_call(response))
print(client.extract_text_or_completion_object(response))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
Expand Down Expand Up @@ -96,6 +139,7 @@ def test_usage_summary():

if __name__ == "__main__":
test_aoai_chat_completion()
test_oai_tool_calling_extraction()
test_chat_completion()
test_completion()
test_cost()
Expand Down
8 changes: 4 additions & 4 deletions test/oai/test_client_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_aoai_chat_completion_stream():
client = OpenAIWrapper(config_list=config_list)
response = client.create(messages=[{"role": "user", "content": "2+2="}], stream=True)
print(response)
print(client.extract_text_or_function_call(response))
print(client.extract_text_or_completion_object(response))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
Expand All @@ -33,7 +33,7 @@ def test_chat_completion_stream():
client = OpenAIWrapper(config_list=config_list)
response = client.create(messages=[{"role": "user", "content": "1+1="}], stream=True)
print(response)
print(client.extract_text_or_function_call(response))
print(client.extract_text_or_completion_object(response))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
Expand Down Expand Up @@ -66,7 +66,7 @@ def test_chat_functions_stream():
stream=True,
)
print(response)
print(client.extract_text_or_function_call(response))
print(client.extract_text_or_completion_object(response))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
Expand All @@ -75,7 +75,7 @@ def test_completion_stream():
client = OpenAIWrapper(config_list=config_list)
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
print(response)
print(client.extract_text_or_function_call(response))
print(client.extract_text_or_completion_object(response))


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion website/docs/Installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Therefore, some changes are required for users of `pyautogen<0.2`.
from autogen import OpenAIWrapper
client = OpenAIWrapper(config_list=config_list)
response = client.create(messages=[{"role": "user", "content": "2+2="}])
print(client.extract_text_or_function_call(response))
print(client.extract_text_or_completion_object(response))
```
- Inference parameter tuning and inference logging features are currently unavailable in `OpenAIWrapper`. Logging will be added in a future release.
Inference parameter tuning can be done via [`flaml.tune`](https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function).
Expand Down
6 changes: 3 additions & 3 deletions website/docs/Use-Cases/enhanced_inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ client = OpenAIWrapper()
# ChatCompletion
response = client.create(messages=[{"role": "user", "content": "2+2="}], model="gpt-3.5-turbo")
# extract the response text
print(client.extract_text_or_function_call(response))
print(client.extract_text_or_completion_object(response))
# get cost of this completion
print(response.cost)
# Azure OpenAI endpoint
client = OpenAIWrapper(api_key=..., base_url=..., api_version=..., api_type="azure")
# Completion
response = client.create(prompt="2+2=", model="gpt-3.5-turbo-instruct")
# extract the response text
print(client.extract_text_or_function_call(response))
print(client.extract_text_or_completion_object(response))

```

Expand Down Expand Up @@ -240,7 +240,7 @@ Another type of error is that the returned response does not satisfy a requireme

```python
def valid_json_filter(response, **_):
for text in OpenAIWrapper.extract_text_or_function_call(response):
for text in OpenAIWrapper.extract_text_or_completion_object(response):
try:
json.loads(text)
return True
Expand Down

0 comments on commit 9cec541

Please sign in to comment.