Enhance vertexai integration (#3086)
* switch to officially supported Vertex AI message sending + safety setting conversion for vertexai

* add system instructions

* fix bug in safety settings conversion

* add missing system instructions

* add safety settings to send message

* add support for credentials objects

* add type checking

* change project_id to project arg

* add more tests

* fix mock creation in test

* extend docstring

* fix errors with gemini message format in chats

* add option for vertexai response validation setting & improve docstring

* re-add empty message handling

* add more tests

* extend and improve gemini vertexai jupyter notebook

* rename project arg to project_id and GOOGLE_API_KEY env var to GOOGLE_GEMINI_API_KEY

* adjust docstring formatting
luxzoli authored Jul 23, 2024
1 parent 1daf852 commit a5e5be7
Showing 9 changed files with 518 additions and 104 deletions.
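Taken together, the changes below widen GeminiClient's config surface: a renamed API key environment variable (GOOGLE_GEMINI_API_KEY), pass-through system instructions, a Vertex AI response_validation toggle, credentials objects, and safety-setting conversion. A minimal sketch of a config entry exercising the new options (field values are illustrative, not taken from this commit):

import os

config_list = [
    {
        "api_type": "google",
        "model": "gemini-pro",
        "api_key": os.environ.get("GOOGLE_GEMINI_API_KEY"),  # renamed from GOOGLE_API_KEY
        "system_instruction": "You are a concise assistant.",  # forwarded to the model by create()
        "response_validation": True,  # Vertex AI chat response validation
        "safety_settings": [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
        ],
    }
]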
11 changes: 10 additions & 1 deletion autogen/agentchat/chat.py
@@ -107,6 +107,15 @@ def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite
return chat_order


def _post_process_carryover_item(carryover_item):
if isinstance(carryover_item, str):
return carryover_item
elif isinstance(carryover_item, dict) and "content" in carryover_item:
return str(carryover_item["content"])
else:
return str(carryover_item)


def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
iostream = IOStream.get_default()

@@ -116,7 +125,7 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
UserWarning,
)
print_carryover = (
("\n").join([t for t in chat_info["carryover"]])
("\n").join([_post_process_carryover_item(t) for t in chat_info["carryover"]])
if isinstance(chat_info["carryover"], list)
else chat_info["carryover"]
)
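The new _post_process_carryover_item helper normalizes carryover entries before they are joined into a single string, so Gemini-style message dicts no longer break the join. A quick behavior sketch based on the code above:

from autogen.agentchat.chat import _post_process_carryover_item

# Strings pass through; dicts with a "content" key are reduced to that content;
# anything else falls back to str().
assert _post_process_carryover_item("plain text") == "plain text"
assert _post_process_carryover_item({"content": "hi", "role": "model"}) == "hi"
assert _post_process_carryover_item(42) == "42"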
3 changes: 2 additions & 1 deletion autogen/agentchat/conversable_agent.py
@@ -11,6 +11,7 @@

from openai import BadRequestError

from autogen.agentchat.chat import _post_process_carryover_item
from autogen.exception_utils import InvalidCarryOverType, SenderRequired

from .._pydantic import model_dump
@@ -2364,7 +2365,7 @@ def _process_carryover(self, content: str, kwargs: dict) -> str:
if isinstance(kwargs["carryover"], str):
content += "\nContext: \n" + kwargs["carryover"]
elif isinstance(kwargs["carryover"], list):
content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
content += "\nContext: \n" + ("\n").join([_post_process_carryover_item(t) for t in kwargs["carryover"]])
else:
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
99 changes: 82 additions & 17 deletions autogen/oai/gemini.py
@@ -6,7 +6,7 @@
"config_list": [{
"api_type": "google",
"model": "gemini-pro",
"api_key": os.environ.get("GOOGLE_API_KEY"),
"api_key": os.environ.get("GOOGLE_GEMINI_API_KEY"),
"safety_settings": [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
@@ -32,6 +32,7 @@
from __future__ import annotations

import base64
import logging
import os
import random
import re
@@ -45,13 +46,19 @@
import vertexai
from google.ai.generativelanguage import Content, Part
from google.api_core.exceptions import InternalServerError
from google.auth.credentials import Credentials
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from PIL import Image
from vertexai.generative_models import Content as VertexAIContent
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import Part as VertexAIPart
from vertexai.generative_models import SafetySetting as VertexAISafetySetting

logger = logging.getLogger(__name__)


class GeminiClient:
@@ -81,29 +88,36 @@ def _initialize_vertexai(self, **params):
vertexai_init_args["project"] = params["project_id"]
if "location" in params:
vertexai_init_args["location"] = params["location"]
if "credentials" in params:
assert isinstance(
params["credentials"], Credentials
), "Object type google.auth.credentials.Credentials is expected!"
vertexai_init_args["credentials"] = params["credentials"]
if vertexai_init_args:
vertexai.init(**vertexai_init_args)

def __init__(self, **kwargs):
"""Uses either either api_key for authentication from the LLM config
(specifying the GOOGLE_API_KEY environment variable also works),
(specifying the GOOGLE_GEMINI_API_KEY environment variable also works),
or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified,
- where project_id and location can also be passed as parameters. Service account key file can also be used.
- If neither a service account key file, nor the api_key are passed, then the default credentials will be used,
- which could be a personal account if the user is already authenticated in, like in Google Cloud Shell.
+ where project_id and location can also be passed as parameters. A previously created credentials object can be
+ provided, or a service account key file can also be used. If neither a service account key file nor the api_key
+ is passed, then the default credentials will be used, which could be a personal account if the user is already
+ authenticated, like in Google Cloud Shell.
Args:
api_key (str): The API key for using Gemini.
credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
google_application_credentials (str): Path to the JSON service account key file of the service account.
Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
can also be set instead of using this argument.
project_id (str): Google Cloud project id, which is only valid in case no API key is specified.
location (str): Compute region to be used, like 'us-west1'.
This parameter is only valid in case no API key is specified.
"""
self.api_key = kwargs.get("api_key", None)
if not self.api_key:
self.api_key = os.getenv("GOOGLE_API_KEY")
self.api_key = os.getenv("GOOGLE_GEMINI_API_KEY")
if self.api_key is None:
self.use_vertexai = True
self._initialize_vertexai(**kwargs)
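With the new credentials handling above, an explicit google.auth credentials object can be supplied through the LLM config and is forwarded to vertexai.init(credentials=...). A sketch assuming a local service account key file (the file path and project values are illustrative):

from google.oauth2 import service_account

# Any google.auth.credentials.Credentials instance passes the isinstance check above.
creds = service_account.Credentials.from_service_account_file("sa-key.json")

config_list = [
    {
        "api_type": "google",
        "model": "gemini-pro",
        "project_id": "my-gcp-project",
        "location": "us-west1",
        "credentials": creds,
    }
]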
@@ -159,13 +173,18 @@ def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])
stream = params.get("stream", False)
n_response = params.get("n", 1)
system_instruction = params.get("system_instruction", None)
response_validation = params.get("response_validation", True)

generation_config = {
gemini_term: params[autogen_term]
for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
if autogen_term in params
}
- safety_settings = params.get("safety_settings", {})
+ if self.use_vertexai:
+     safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
+ else:
+     safety_settings = params.get("safety_settings", {})

if stream:
warnings.warn(
@@ -181,20 +200,29 @@
gemini_messages = self._oai_messages_to_gemini_messages(messages)
if self.use_vertexai:
model = GenerativeModel(
- model_name, generation_config=generation_config, safety_settings=safety_settings
+ model_name,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
+ system_instruction=system_instruction,
)
chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
else:
# we use chat model by default
model = genai.GenerativeModel(
- model_name, generation_config=generation_config, safety_settings=safety_settings
+ model_name,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
+ system_instruction=system_instruction,
)
genai.configure(api_key=self.api_key)
chat = model.start_chat(history=gemini_messages[:-1])
max_retries = 5
for attempt in range(max_retries):
ans = None
try:
- response = chat.send_message(gemini_messages[-1], stream=stream)
+ response = chat.send_message(
+     gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings
+ )
except InternalServerError:
delay = 5 * (2**attempt)
warnings.warn(
@@ -218,16 +246,22 @@
# B. handle the vision model
if self.use_vertexai:
model = GenerativeModel(
- model_name, generation_config=generation_config, safety_settings=safety_settings
+ model_name,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
+ system_instruction=system_instruction,
)
else:
model = genai.GenerativeModel(
- model_name, generation_config=generation_config, safety_settings=safety_settings
+ model_name,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
+ system_instruction=system_instruction,
)
genai.configure(api_key=self.api_key)
# Gemini's vision model does not support chat history yet
# chat = model.start_chat(history=gemini_messages[:-1])
- # response = chat.send_message(gemini_messages[-1])
+ # response = chat.send_message(gemini_messages[-1].parts)
user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
if len(messages) > 2:
warnings.warn(
@@ -270,6 +304,8 @@ def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List:
"""Convert content from OAI format to Gemini format"""
rst = []
if isinstance(content, str):
if content == "":
content = "empty" # Empty content is not allowed.
if self.use_vertexai:
rst.append(VertexAIPart.from_text(content))
else:
@@ -372,6 +408,35 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li

return rst

@staticmethod
def _to_vertexai_safety_settings(safety_settings):
"""Convert safety settings to VertexAI format if needed,
like when specifying them in the OAI_CONFIG_LIST
"""
if isinstance(safety_settings, list) and all(
[
isinstance(safety_setting, dict) and not isinstance(safety_setting, VertexAISafetySetting)
for safety_setting in safety_settings
]
):
vertexai_safety_settings = []
for safety_setting in safety_settings:
if safety_setting["category"] not in VertexAIHarmCategory.__members__:
invalid_category = safety_setting["category"]
logger.error(f"Safety setting category {invalid_category} is invalid")
elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__:
invalid_threshold = safety_setting["threshold"]
logger.error(f"Safety threshold {invalid_threshold} is invalid")
else:
vertexai_safety_setting = VertexAISafetySetting(
category=safety_setting["category"],
threshold=safety_setting["threshold"],
)
vertexai_safety_settings.append(vertexai_safety_setting)
return vertexai_safety_settings
else:
return safety_settings


def _to_pil(data: str) -> Image.Image:
"""
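_to_vertexai_safety_settings accepts the plain-dict form used in an OAI_CONFIG_LIST and returns vertexai SafetySetting objects, logging and skipping invalid entries, while input that is already in VertexAI format is returned unchanged. A sketch of the expected behavior:

from autogen.oai.gemini import GeminiClient
from vertexai.generative_models import SafetySetting

raw = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
    {"category": "NOT_A_REAL_CATEGORY", "threshold": "BLOCK_ONLY_HIGH"},  # logged, then skipped
]
converted = GeminiClient._to_vertexai_safety_settings(raw)
assert len(converted) == 1 and isinstance(converted[0], SafetySetting)

# Already-converted settings (not a list of plain dicts) pass through as-is.
assert GeminiClient._to_vertexai_safety_settings(converted) is converted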
10 changes: 9 additions & 1 deletion autogen/oai/openai_utils.py
@@ -13,7 +13,15 @@
from openai.types.beta.assistant import Assistant
from packaging.version import parse

NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version", "azure_ad_token", "azure_ad_token_provider"]
NON_CACHE_KEY = [
"api_key",
"base_url",
"api_type",
"api_version",
"azure_ad_token",
"azure_ad_token_provider",
"credentials",
]
DEFAULT_AZURE_API_VERSION = "2024-02-01"
OAI_PRICE1K = {
# https://openai.com/api/pricing/
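Adding "credentials" to NON_CACHE_KEY keeps the secret-bearing (and unhashable) credentials object out of cache-key computation. A sketch of the idea; the helper below is illustrative, not the module's actual implementation:

def _cacheable_view(config: dict) -> dict:
    # Drop auth material and other non-semantic fields before deriving a cache key.
    return {k: v for k, v in config.items() if k not in NON_CACHE_KEY}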
11 changes: 11 additions & 0 deletions test/agentchat/test_chats.py
@@ -10,6 +10,7 @@

import autogen
from autogen import AssistantAgent, GroupChat, GroupChatManager, UserProxyAgent, filter_config, initiate_chats
from autogen.agentchat.chat import _post_process_carryover_item

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from conftest import reason, skip_openai # noqa: E402
@@ -620,6 +621,15 @@ def my_writing_task(sender, recipient, context):
print(chat_results[1].summary, chat_results[1].cost)


def test_post_process_carryover_item():
gemini_carryover_item = {"content": "How can I help you?", "role": "model"}
assert (
_post_process_carryover_item(gemini_carryover_item) == gemini_carryover_item["content"]
), "Incorrect carryover postprocessing"
carryover_item = "How can I help you?"
assert _post_process_carryover_item(carryover_item) == carryover_item, "Incorrect carryover postprocessing"


if __name__ == "__main__":
test_chats()
# test_chats_general()
@@ -628,3 +638,4 @@ def my_writing_task(sender, recipient, context):
# test_chats_w_func()
# test_chat_messages_for_summary()
# test_udf_message_in_chats()
test_post_process_carryover_item()
56 changes: 56 additions & 0 deletions test/agentchat/test_conversable_agent.py
@@ -1463,6 +1463,58 @@ def sample_function():
)


def test_process_gemini_carryover():
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
content = "I am your assistant."
carryover_content = "How can I help you?"
gemini_kwargs = {"carryover": [{"content": carryover_content}]}
proc_content = dummy_agent_1._process_carryover(content=content, kwargs=gemini_kwargs)
assert proc_content == content + "\nContext: \n" + carryover_content, "Incorrect carryover processing"


def test_process_carryover():
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
content = "I am your assistant."
carryover = "How can I help you?"
kwargs = {"carryover": carryover}
proc_content = dummy_agent_1._process_carryover(content=content, kwargs=kwargs)
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"

carryover_l = ["How can I help you?"]
kwargs = {"carryover": carryover_l}
proc_content = dummy_agent_1._process_carryover(content=content, kwargs=kwargs)
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"

proc_content_empty_carryover = dummy_agent_1._process_carryover(content=content, kwargs={"carryover": None})
assert proc_content_empty_carryover == content, "Incorrect carryover processing"


def test_handle_gemini_carryover():
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
content = "I am your assistant"
carryover_content = "How can I help you?"
gemini_kwargs = {"carryover": [{"content": carryover_content}]}
proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=gemini_kwargs)
assert proc_content == content + "\nContext: \n" + carryover_content, "Incorrect carryover processing"


def test_handle_carryover():
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
content = "I am your assistant."
carryover = "How can I help you?"
kwargs = {"carryover": carryover}
proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=kwargs)
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"

carryover_l = ["How can I help you?"]
kwargs = {"carryover": carryover_l}
proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=kwargs)
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"

proc_content_empty_carryover = dummy_agent_1._handle_carryover(message=content, kwargs={"carryover": None})
assert proc_content_empty_carryover == content, "Incorrect carryover processing"


if __name__ == "__main__":
# test_trigger()
# test_context()
@@ -1473,6 +1525,10 @@ def sample_function():
# test_max_turn()
# test_process_before_send()
# test_message_func()

test_summary()
test_adding_duplicate_function_warning()
# test_function_registration_e2e_sync()

test_process_gemini_carryover()
test_process_carryover()
