From 7e38951687a2bafa71cd19cf7b7e76407840cf68 Mon Sep 17 00:00:00 2001 From: yeounoh Date: Wed, 24 Jul 2024 10:48:52 -0700 Subject: [PATCH 1/6] Refactoring GeminiClient. --- autogen/oai/gemini.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py index 33790c9851c..b00636a1364 100644 --- a/autogen/oai/gemini.py +++ b/autogen/oai/gemini.py @@ -57,6 +57,7 @@ from vertexai.generative_models import HarmCategory as VertexAIHarmCategory from vertexai.generative_models import Part as VertexAIPart from vertexai.generative_models import SafetySetting as VertexAISafetySetting +from vertexai.preview import caching logger = logging.getLogger(__name__) @@ -129,6 +130,11 @@ def __init__(self, **kwargs): assert ("project_id" not in kwargs) and ( "location" not in kwargs ), "Google Cloud project and compute location cannot be set when using an API Key!" + genai.configure(api_key=self.api_key) + # Update the following aliases so that it calls generative AI SDK. + # Vertex AI SDK offers extra features that should be gated by `use_vertexai`. + GenerativeModel = genai.GenerativeModel + caching = genai.caching def message_retrieval(self, response) -> List: """ @@ -198,23 +204,18 @@ def create(self, params: Dict) -> ChatCompletion: if "vision" not in model_name: # A. create and call the chat model. gemini_messages = self._oai_messages_to_gemini_messages(messages) - if self.use_vertexai: - model = GenerativeModel( + model = GenerativeModel( model_name, generation_config=generation_config, safety_settings=safety_settings, system_instruction=system_instruction, ) + if self.use_vertexai: + # `response_validation=True` (default) sanitizes the chat history by logging + # only valid and complete messages. Blocked messages should be excluded to keep + # the chat session state usable. This is only available in Vertex AI SDK. 
chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation) else: - # we use chat model by default - model = genai.GenerativeModel( - model_name, - generation_config=generation_config, - safety_settings=safety_settings, - system_instruction=system_instruction, - ) - genai.configure(api_key=self.api_key) chat = model.start_chat(history=gemini_messages[:-1]) max_retries = 5 for attempt in range(max_retries): @@ -258,7 +259,6 @@ def create(self, params: Dict) -> ChatCompletion: safety_settings=safety_settings, system_instruction=system_instruction, ) - genai.configure(api_key=self.api_key) # Gemini's vision model does not support chat history yet # chat = model.start_chat(history=gemini_messages[:-1]) # response = chat.send_message(gemini_messages[-1].parts) From 49df9110f34ecb5e8c7af5b78669c6d3affd0eed Mon Sep 17 00:00:00 2001 From: yeounoh Date: Wed, 24 Jul 2024 13:31:03 -0700 Subject: [PATCH 2/6] Add GeminiContextCache --- autogen/oai/gemini.py | 173 +++++++++++++++++++++++++++++++++--------- 1 file changed, 137 insertions(+), 36 deletions(-) diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py index b00636a1364..cfaf550c1ed 100644 --- a/autogen/oai/gemini.py +++ b/autogen/oai/gemini.py @@ -32,6 +32,7 @@ from __future__ import annotations import base64 +import datetime import logging import os import random @@ -42,6 +43,7 @@ from typing import Any, Dict, List, Mapping, Union import google.generativeai as genai +from google.generativeai import protos import requests import vertexai from google.ai.generativelanguage import Content, Part @@ -83,7 +85,8 @@ class GeminiClient: def _initialize_vertexai(self, **params): if "google_application_credentials" in params: # Path to JSON Keyfile - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params["google_application_credentials"] + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params[ + "google_application_credentials"] vertexai_init_args = {} if "project_id" in params: vertexai_init_args["project"] = params["project_id"] @@ -131,10 +134,6 @@ def __init__(self, **kwargs): "location" not in kwargs ), "Google Cloud project and compute location cannot be set when using an API Key!" genai.configure(api_key=self.api_key) - # Update the following aliases so that it calls generative AI SDK. - # Vertex AI SDK offers extra features that should be gated by `use_vertexai`. - GenerativeModel = genai.GenerativeModel - caching = genai.caching def message_retrieval(self, response) -> List: """ @@ -188,7 +187,8 @@ def create(self, params: Dict) -> ChatCompletion: if autogen_term in params } if self.use_vertexai: - safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {})) + safety_settings = GeminiClient._to_vertexai_safety_settings( + params.get("safety_settings", {})) else: safety_settings = params.get("safety_settings", {}) @@ -199,31 +199,42 @@ def create(self, params: Dict) -> ChatCompletion: ) if n_response > 1: - warnings.warn("Gemini only supports `n=1` for now. We only generate one response.", UserWarning) + warnings.warn( + "Gemini only supports `n=1` for now. We only generate one response.", + UserWarning) if "vision" not in model_name: # A. create and call the chat model. 
gemini_messages = self._oai_messages_to_gemini_messages(messages) - model = GenerativeModel( + if self.use_vertexai: + model = GenerativeModel( model_name, generation_config=generation_config, safety_settings=safety_settings, system_instruction=system_instruction, ) - if self.use_vertexai: # `response_validation=True` (default) sanitizes the chat history by logging # only valid and complete messages. Blocked messages should be excluded to keep # the chat session state usable. This is only available in Vertex AI SDK. - chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation) + chat = model.start_chat( + history=gemini_messages[:-1], + response_validation=response_validation) else: + model = genai.GenerativeModel( + model_name, + generation_config=generation_config, + safety_settings=safety_settings, + system_instruction=system_instruction, + ) chat = model.start_chat(history=gemini_messages[:-1]) max_retries = 5 for attempt in range(max_retries): ans = None try: response = chat.send_message( - gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings - ) + gemini_messages[-1].parts, + stream=stream, + safety_settings=safety_settings) except InternalServerError: delay = 5 * (2**attempt) warnings.warn( @@ -232,14 +243,18 @@ def create(self, params: Dict) -> ChatCompletion: ) time.sleep(delay) except Exception as e: - raise RuntimeError(f"Google GenAI exception occurred while calling Gemini API: {e}") + raise RuntimeError( + f"Google GenAI exception occurred while calling Gemini API: {e}" + ) else: # `ans = response.text` is unstable. Use the following code instead. ans: str = chat.history[-1].parts[0].text break if ans is None: - raise RuntimeError(f"Fail to get response from Google AI after retrying {attempt + 1} times.") + raise RuntimeError( + f"Fail to get response from Google AI after retrying {attempt + 1} times." + ) prompt_tokens = model.count_tokens(chat.history[:-1]).total_tokens completion_tokens = model.count_tokens(ans).total_tokens @@ -262,7 +277,8 @@ def create(self, params: Dict) -> ChatCompletion: # Gemini's vision model does not support chat history yet # chat = model.start_chat(history=gemini_messages[:-1]) # response = chat.send_message(gemini_messages[-1].parts) - user_message = self._oai_content_to_gemini_content(messages[-1]["content"]) + user_message = self._oai_content_to_gemini_content( + messages[-1]["content"]) if len(messages) > 2: warnings.warn( "Warning: Gemini's vision model does not support chat history yet.", @@ -281,7 +297,10 @@ def create(self, params: Dict) -> ChatCompletion: completion_tokens = model.count_tokens(ans).total_tokens # 3. 
convert output - message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None) + message = ChatCompletionMessage(role="assistant", + content=ans, + function_call=None, + tool_calls=None) choices = [Choice(finish_reason="stop", index=0, message=message)] response_oai = ChatCompletion( @@ -295,12 +314,14 @@ def create(self, params: Dict) -> ChatCompletion: completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ), - cost=calculate_gemini_cost(prompt_tokens, completion_tokens, model_name), + cost=calculate_gemini_cost(prompt_tokens, completion_tokens, + model_name), ) return response_oai - def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List: + def _oai_content_to_gemini_content(self, content: Union[str, + List]) -> List: """Convert content from OAI format to Gemini format""" rst = [] if isinstance(content, str): @@ -328,14 +349,16 @@ def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List: re.match(r"data:image/(?:png|jpeg);base64,", img_url) img = get_image_data(img_url, use_b64=False) # image/png works with jpeg as well - img_part = VertexAIPart.from_data(img, mime_type="image/png") + img_part = VertexAIPart.from_data( + img, mime_type="image/png") rst.append(img_part) else: b64_img = get_image_data(msg["image_url"]["url"]) img = _to_pil(b64_img) rst.append(img) else: - raise ValueError(f"Unsupported message type: {msg['type']}") + raise ValueError( + f"Unsupported message type: {msg['type']}") else: raise ValueError(f"Unsupported message type: {type(msg)}") return rst @@ -353,7 +376,8 @@ def _concat_parts(self, parts: List[Part]) -> List: for current_part in parts[1:]: if previous_part.text != "": if self.use_vertexai: - previous_part = VertexAIPart.from_text(previous_part.text + current_part.text) + previous_part = VertexAIPart.from_text(previous_part.text + + current_part.text) else: previous_part.text += current_part.text else: @@ -369,7 +393,8 @@ def _concat_parts(self, parts: List[Part]) -> List: return concatenated_parts - def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> list[dict[str, Any]]: + def _oai_messages_to_gemini_messages( + self, messages: list[Dict[str, Any]]) -> list[dict[str, Any]]: """Convert messages from OAI format to Gemini format. Make sure the "user" role and "model" role are interleaved. Also, make sure the last item is from the "user" role. @@ -384,7 +409,8 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li curr_parts += parts elif role != prev_role: if self.use_vertexai: - rst.append(VertexAIContent(parts=curr_parts, role=prev_role)) + rst.append( + VertexAIContent(parts=curr_parts, role=prev_role)) else: rst.append(Content(parts=curr_parts, role=prev_role)) curr_parts = parts @@ -402,9 +428,15 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li # We add a dummy message "continue" if the last role is not the user. 
if rst[-1].role != "user": if self.use_vertexai: - rst.append(VertexAIContent(parts=self._oai_content_to_gemini_content("continue"), role="user")) + rst.append( + VertexAIContent( + parts=self._oai_content_to_gemini_content("continue"), + role="user")) else: - rst.append(Content(parts=self._oai_content_to_gemini_content("continue"), role="user")) + rst.append( + Content( + parts=self._oai_content_to_gemini_content("continue"), + role="user")) return rst @@ -413,20 +445,24 @@ def _to_vertexai_safety_settings(safety_settings): """Convert safety settings to VertexAI format if needed, like when specifying them in the OAI_CONFIG_LIST """ - if isinstance(safety_settings, list) and all( - [ - isinstance(safety_setting, dict) and not isinstance(safety_setting, VertexAISafetySetting) + if isinstance(safety_settings, list) and all([ + isinstance(safety_setting, dict) + and not isinstance(safety_setting, VertexAISafetySetting) for safety_setting in safety_settings - ] - ): + ]): vertexai_safety_settings = [] for safety_setting in safety_settings: - if safety_setting["category"] not in VertexAIHarmCategory.__members__: + if safety_setting[ + "category"] not in VertexAIHarmCategory.__members__: invalid_category = safety_setting["category"] - logger.error(f"Safety setting category {invalid_category} is invalid") - elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__: + logger.error( + f"Safety setting category {invalid_category} is invalid" + ) + elif safety_setting[ + "threshold"] not in VertexAIHarmBlockThreshold.__members__: invalid_threshold = safety_setting["threshold"] - logger.error(f"Safety threshold {invalid_threshold} is invalid") + logger.error( + f"Safety threshold {invalid_threshold} is invalid") else: vertexai_safety_setting = VertexAISafetySetting( category=safety_setting["category"], @@ -438,6 +474,68 @@ def _to_vertexai_safety_settings(safety_settings): return safety_settings +class GeminiContextCache: + """ + Context cache for Gemini models. Context cache helps reduce the cost by caching + the same input tokens that are used repeatedly. A cache instance is created using + a publisher model and the model name is immutable once the cache is created. + The created cache has TTL (1 hour by default) and this can be updated after the creation. + The cost for caching depends on the input token size and how long you want the tokens to persist. + Context cache is available in Gemini 1.5. + """ + + def __init__(self, model: str, display_name: str, system_instruction: str, + contents: list[str], ttl: datetime.timedelta, use_vertexai=True): + self.use_vertexai = use_vertexai + _caching = caching if use_vertexai else genai.caching + self.cache = _caching.CachedContent.create( + model=model, + display_name=display_name, + system_instruction=system_instruction, + contents=contents, + ttl=ttl) + + def is_compatible(self, model: Union[GenerativeModel, genai.GenerativeModel]) -> bool: + """ + Verify if this cache is compatible with a given model. + """ + # Context cache is available in gemini 1.5 stable versions. 
+ if re.match(r"^gemini-1\.5-(pro|flash)-\d{3}$", model._model_name): + if ((self.use_vertexai and isinstance(model, GenerativeModel)) + or (not self.use_vertexai + and isinstance(model, genai.GenerativeModel))): + return True + warnings.warn( + "Cache was created using a different SDK than the model: " + f"use_vertexai={self.use_vertexai}, type(model)={type(model)}") + return False + + def update_ttl(self, ttl: datetime.timedelta): + self.cache.update(ttl=ttl) + + def delete(self): + self.cache.delete() + + @property + def model(self) -> str: + return self.cache._proto.model + + @property + def display_name(self) -> str: + return self.cache._proto.display_name + + @property + def usage_metadata(self) -> protos.CachedContent.UsageMetadata: + return self.cache._proto.usage_metadata + + @property + def expire_time(self) -> datetime.datetime: + return self.cache.expire_time() + + def __str__(self): + return self.cache.__str__() + + def _to_pil(data: str) -> Image.Image: """ Converts a base64 encoded image data string to a PIL Image object. @@ -472,14 +570,17 @@ def get_image_data(image_file: str, use_b64=True) -> bytes: return content -def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float: +def calculate_gemini_cost(input_tokens: int, output_tokens: int, + model_name: str) -> float: if "1.5" in model_name or "gemini-experimental" in model_name: # "gemini-1.5-pro-preview-0409" # Cost is $7 per million input tokens and $21 per million output tokens return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6 if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name: - warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning) + warnings.warn( + f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", + UserWarning) # Cost is $0.5 per million input tokens and $1.5 per million output tokens return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6 From 3207461e68bd1e219646f665d95f95ada69a3983 Mon Sep 17 00:00:00 2001 From: yeounoh Date: Wed, 24 Jul 2024 15:25:21 -0700 Subject: [PATCH 3/6] Create gemini model with context caching --- autogen/oai/gemini.py | 71 ++++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 35 deletions(-) diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py index cfaf550c1ed..eb6a9222340 100644 --- a/autogen/oai/gemini.py +++ b/autogen/oai/gemini.py @@ -134,6 +134,7 @@ def __init__(self, **kwargs): "location" not in kwargs ), "Google Cloud project and compute location cannot be set when using an API Key!" genai.configure(api_key=self.api_key) + self.context_cache = None def message_retrieval(self, response) -> List: """ @@ -180,6 +181,8 @@ def create(self, params: Dict) -> ChatCompletion: n_response = params.get("n", 1) system_instruction = params.get("system_instruction", None) response_validation = params.get("response_validation", True) + context_cache = params.get('context_cache', None) + self.context_cache = context_cache # Keep the cache reference used at the creation time generation_config = { gemini_term: params[autogen_term] @@ -203,16 +206,22 @@ def create(self, params: Dict) -> ChatCompletion: "Gemini only supports `n=1` for now. 
We only generate one response.", UserWarning) + gen_model_cls = GenerativeModel if self.use_vertexai else genai.GenerativeModel + if context_cache: + model = gen_model_cls( + model_name, + generation_config=generation_config, + safety_settings=safety_settings, + system_instruction=system_instruction, + ) + else: + # Context prefix caching can help reduce the cost. + model = gen_model_cls.from_cached_content(cached_content=context_cache) + if "vision" not in model_name: # A. create and call the chat model. gemini_messages = self._oai_messages_to_gemini_messages(messages) if self.use_vertexai: - model = GenerativeModel( - model_name, - generation_config=generation_config, - safety_settings=safety_settings, - system_instruction=system_instruction, - ) # `response_validation=True` (default) sanitizes the chat history by logging # only valid and complete messages. Blocked messages should be excluded to keep # the chat session state usable. This is only available in Vertex AI SDK. @@ -220,12 +229,6 @@ def create(self, params: Dict) -> ChatCompletion: history=gemini_messages[:-1], response_validation=response_validation) else: - model = genai.GenerativeModel( - model_name, - generation_config=generation_config, - safety_settings=safety_settings, - system_instruction=system_instruction, - ) chat = model.start_chat(history=gemini_messages[:-1]) max_retries = 5 for attempt in range(max_retries): @@ -259,21 +262,7 @@ def create(self, params: Dict) -> ChatCompletion: prompt_tokens = model.count_tokens(chat.history[:-1]).total_tokens completion_tokens = model.count_tokens(ans).total_tokens elif model_name == "gemini-pro-vision": - # B. handle the vision model - if self.use_vertexai: - model = GenerativeModel( - model_name, - generation_config=generation_config, - safety_settings=safety_settings, - system_instruction=system_instruction, - ) - else: - model = genai.GenerativeModel( - model_name, - generation_config=generation_config, - safety_settings=safety_settings, - system_instruction=system_instruction, - ) + # B. handle the vision model. # Gemini's vision model does not support chat history yet # chat = model.start_chat(history=gemini_messages[:-1]) # response = chat.send_message(gemini_messages[-1].parts) @@ -302,6 +291,9 @@ def create(self, params: Dict) -> ChatCompletion: function_call=None, tool_calls=None) choices = [Choice(finish_reason="stop", index=0, message=message)] + context_cache_tokens = int( + self.context_cache.usage_metadata.total_token_count if self. + context_cache else 0) response_oai = ChatCompletion( id=str(random.randint(0, 1000)), @@ -314,7 +306,8 @@ def create(self, params: Dict) -> ChatCompletion: completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ), - cost=calculate_gemini_cost(prompt_tokens, completion_tokens, + cost=calculate_gemini_cost(prompt_tokens - context_cache_tokens, + completion_tokens, context_cache_tokens, model_name), ) @@ -476,8 +469,11 @@ def _to_vertexai_safety_settings(safety_settings): class GeminiContextCache: """ - Context cache for Gemini models. Context cache helps reduce the cost by caching - the same input tokens that are used repeatedly. A cache instance is created using + Context cache for Gemini models. The semantics of this cache operation is different + from the generic autogen.cache, where the input prompt and the agent outputs are cached. + Here, context cache stores the common prefix tokens to Gemini models. + + Context cache helps reduce the cost by caching the same input tokens that are used repeatedly. 
A cache instance is created using a publisher model and the model name is immutable once the cache is created. The created cache has TTL (1 hour by default) and this can be updated after the creation. The cost for caching depends on the input token size and how long you want the tokens to persist. @@ -518,15 +514,19 @@ def delete(self): @property def model(self) -> str: - return self.cache._proto.model + return self.cache.model() + + @property + def name(self) -> str: + return self.cache.name() @property def display_name(self) -> str: - return self.cache._proto.display_name + return self.cache.display_name() @property def usage_metadata(self) -> protos.CachedContent.UsageMetadata: - return self.cache._proto.usage_metadata + return self.cache.usage_metadata() @property def expire_time(self) -> datetime.datetime: @@ -571,11 +571,12 @@ def get_image_data(image_file: str, use_b64=True) -> bytes: def calculate_gemini_cost(input_tokens: int, output_tokens: int, - model_name: str) -> float: + context_cache_tokens: int, model_name: str) -> float: + # TODO(yeounoh) - update the pricing model to reflect the prompt size if "1.5" in model_name or "gemini-experimental" in model_name: # "gemini-1.5-pro-preview-0409" # Cost is $7 per million input tokens and $21 per million output tokens - return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6 + return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6 + 1.75 * context_cache_tokens / 1e6 if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name: warnings.warn( From 491b67ecd335f8591c26301d6f6aecacd6323174 Mon Sep 17 00:00:00 2001 From: yeounoh Date: Wed, 24 Jul 2024 15:55:45 -0700 Subject: [PATCH 4/6] Add a unit test for cost calculation --- autogen/oai/gemini.py | 1 + test/oai/test_gemini.py | 23 ++++++++++++++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py index eb6a9222340..0f077e9b11c 100644 --- a/autogen/oai/gemini.py +++ b/autogen/oai/gemini.py @@ -146,6 +146,7 @@ def message_retrieval(self, response) -> List: return [choice.message for choice in response.choices] def cost(self, response) -> float: + # TODO(yeounoh) should use cost calculation function. return response.cost @staticmethod diff --git a/test/oai/test_gemini.py b/test/oai/test_gemini.py index 61fdbe6d735..592054477fc 100644 --- a/test/oai/test_gemini.py +++ b/test/oai/test_gemini.py @@ -268,15 +268,36 @@ def test_internal_server_error_retry(mock_genai, gemini_client): # Test cost calculation @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") def test_cost_calculation(gemini_client, mock_response): + # TODO(yeounoh) - update the test case so that it is more meaningful. response = mock_response( text="Example response", choices=[{"message": "Test message 1"}], usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, - cost=0.01, + cost=0.000175, model="gemini-pro", ) assert gemini_client.cost(response) > 0, "Cost should be correctly calculated as zero" + response_with_cache = mock_response( + text="Example response", + choices=[{ + "message": "Test message 1" + }], + usage={ + # openai usage stats do not reflect gemini context caching. + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + # context_cache_tokens should offset prompt_tokens and reduce the + # total cost durign the cost calculation. 
+            "context_cache_tokens": 3
+        },
+        cost=0.00015925,
+        model="gemini-pro",
+    )
+    assert gemini_client.cost(response) > gemini_client.cost(response_with_cache), \
+        "Context caching should reduce the cost."
+

 @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
 @patch("autogen.oai.gemini.genai.GenerativeModel")


From d38df67f44ef20c02a73c5c56e3fbc26381766f8 Mon Sep 17 00:00:00 2001
From: yeounoh
Date: Thu, 25 Jul 2024 00:15:29 -0700
Subject: [PATCH 5/6] Add another unit test

---
 autogen/oai/gemini.py   |  7 +++---
 test/oai/test_gemini.py | 49 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py
index 0f077e9b11c..a0ee8ffa044 100644
--- a/autogen/oai/gemini.py
+++ b/autogen/oai/gemini.py
@@ -209,15 +209,16 @@ def create(self, params: Dict) -> ChatCompletion:
 
         gen_model_cls = GenerativeModel if self.use_vertexai else genai.GenerativeModel
         if context_cache:
+            # Context prefix caching can help reduce the cost.
+            model = gen_model_cls.from_cached_content(cached_content=context_cache)
+        else:
             model = gen_model_cls(
                 model_name,
                 generation_config=generation_config,
                 safety_settings=safety_settings,
                 system_instruction=system_instruction,
             )
-        else:
-            # Context prefix caching can help reduce the cost.
-            model = gen_model_cls.from_cached_content(cached_content=context_cache)
+
 
         if "vision" not in model_name:
             # A. create and call the chat model.
             gemini_messages = self._oai_messages_to_gemini_messages(messages)
diff --git a/test/oai/test_gemini.py b/test/oai/test_gemini.py
index 592054477fc..83eb528e759 100644
--- a/test/oai/test_gemini.py
+++ b/test/oai/test_gemini.py
@@ -1,6 +1,7 @@
 import os
 from unittest.mock import MagicMock, patch
 
+from autogen.oai.gemini import calculate_gemini_cost
 import pytest
 
 try:
@@ -13,6 +14,7 @@ from vertexai.generative_models import SafetySetting as VertexAISafetySetting
 
     from autogen.oai.gemini import GeminiClient
+    from autogen.oai.gemini import GeminiContextCache
 
     skip = False
 except ImportError:
@@ -382,6 +384,53 @@ def test_vertexai_default_auth_create_response(mock_init, mock_generative_model,
     # Assertions to check if response is structured as expected
     assert response.choices[0].message.content == "Example response", "Response content should match expected output"
 
+@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
+@patch("autogen.oai.gemini.GenerativeModel")
+@patch("autogen.oai.gemini.vertexai.init")
+def test_vertexai_default_auth_create_response_with_context_cache(mock_init, mock_generative_model, gemini_google_auth_default_client):
+    # Mock the genai model configuration and creation process
+    mock_chat = MagicMock()
+    mock_model = MagicMock()
+    mock_init.return_value = None
+    mock_generative_model.return_value = mock_model
+    mock_model.start_chat.return_value = mock_chat
+
+    # Set up a mock for the chat history item access and the text attribute return
+    mock_history_part = MagicMock()
+    mock_history_part.text = "Example response"
+    mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part
+
+    # Setup the mock to return a mocked chat response
+    mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])])
+
+    # Setup the mock to return a mocked cache usage
+    mock_context_cache = MagicMock(usage_metadata=MagicMock(total_token_count = 10))
+
+    # Call the create method
+    response = gemini_google_auth_default_client.create(
+        {"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}],
"stream": False} + ) + response_with_cache = gemini_google_auth_default_client.create({ + "model": + "gemini-1.5-pro-001", + "context_cache": + mock_context_cache, + "messages": [{ + "content": "Hello", + "role": "user" + }], + "stream": + False + }) + + # Assertions to check if response is structured as expected + assert response_with_cache.choices[ + 0].message.content == "Example response", "Response content should match expected output" + assert gemini_google_auth_default_client.cost( + response) > gemini_google_auth_default_client.cost( + response_with_cache + ), "Context caching should result in reduced cost." + @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") @patch("autogen.oai.gemini.genai.GenerativeModel") From aef423176111d23246e574c80a24e2d988b96cf3 Mon Sep 17 00:00:00 2001 From: Yeounoh Chung Date: Thu, 25 Jul 2024 21:56:05 +0000 Subject: [PATCH 6/6] Linting --- autogen/oai/gemini.py | 134 +++++++++++++++------------------------- test/oai/test_gemini.py | 56 ++++++++--------- 2 files changed, 78 insertions(+), 112 deletions(-) diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py index a0ee8ffa044..14bfd931715 100644 --- a/autogen/oai/gemini.py +++ b/autogen/oai/gemini.py @@ -43,12 +43,12 @@ from typing import Any, Dict, List, Mapping, Union import google.generativeai as genai -from google.generativeai import protos import requests import vertexai from google.ai.generativelanguage import Content, Part from google.api_core.exceptions import InternalServerError from google.auth.credentials import Credentials +from google.generativeai import protos from openai.types.chat import ChatCompletion from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage @@ -85,8 +85,7 @@ class GeminiClient: def _initialize_vertexai(self, **params): if "google_application_credentials" in params: # Path to JSON Keyfile - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params[ - "google_application_credentials"] + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params["google_application_credentials"] vertexai_init_args = {} if "project_id" in params: vertexai_init_args["project"] = params["project_id"] @@ -182,7 +181,7 @@ def create(self, params: Dict) -> ChatCompletion: n_response = params.get("n", 1) system_instruction = params.get("system_instruction", None) response_validation = params.get("response_validation", True) - context_cache = params.get('context_cache', None) + context_cache = params.get("context_cache", None) self.context_cache = context_cache # Keep the cache reference used at the creation time generation_config = { @@ -191,8 +190,7 @@ def create(self, params: Dict) -> ChatCompletion: if autogen_term in params } if self.use_vertexai: - safety_settings = GeminiClient._to_vertexai_safety_settings( - params.get("safety_settings", {})) + safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {})) else: safety_settings = params.get("safety_settings", {}) @@ -203,9 +201,7 @@ def create(self, params: Dict) -> ChatCompletion: ) if n_response > 1: - warnings.warn( - "Gemini only supports `n=1` for now. We only generate one response.", - UserWarning) + warnings.warn("Gemini only supports `n=1` for now. 
We only generate one response.", UserWarning) gen_model_cls = GenerativeModel if self.use_vertexai else genai.GenerativeModel if context_cache: @@ -219,7 +215,6 @@ def create(self, params: Dict) -> ChatCompletion: system_instruction=system_instruction, ) - if "vision" not in model_name: # A. create and call the chat model. gemini_messages = self._oai_messages_to_gemini_messages(messages) @@ -227,9 +222,7 @@ def create(self, params: Dict) -> ChatCompletion: # `response_validation=True` (default) sanitizes the chat history by logging # only valid and complete messages. Blocked messages should be excluded to keep # the chat session state usable. This is only available in Vertex AI SDK. - chat = model.start_chat( - history=gemini_messages[:-1], - response_validation=response_validation) + chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation) else: chat = model.start_chat(history=gemini_messages[:-1]) max_retries = 5 @@ -237,9 +230,8 @@ def create(self, params: Dict) -> ChatCompletion: ans = None try: response = chat.send_message( - gemini_messages[-1].parts, - stream=stream, - safety_settings=safety_settings) + gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings + ) except InternalServerError: delay = 5 * (2**attempt) warnings.warn( @@ -248,18 +240,14 @@ def create(self, params: Dict) -> ChatCompletion: ) time.sleep(delay) except Exception as e: - raise RuntimeError( - f"Google GenAI exception occurred while calling Gemini API: {e}" - ) + raise RuntimeError(f"Google GenAI exception occurred while calling Gemini API: {e}") else: # `ans = response.text` is unstable. Use the following code instead. ans: str = chat.history[-1].parts[0].text break if ans is None: - raise RuntimeError( - f"Fail to get response from Google AI after retrying {attempt + 1} times." - ) + raise RuntimeError(f"Fail to get response from Google AI after retrying {attempt + 1} times.") prompt_tokens = model.count_tokens(chat.history[:-1]).total_tokens completion_tokens = model.count_tokens(ans).total_tokens @@ -268,8 +256,7 @@ def create(self, params: Dict) -> ChatCompletion: # Gemini's vision model does not support chat history yet # chat = model.start_chat(history=gemini_messages[:-1]) # response = chat.send_message(gemini_messages[-1].parts) - user_message = self._oai_content_to_gemini_content( - messages[-1]["content"]) + user_message = self._oai_content_to_gemini_content(messages[-1]["content"]) if len(messages) > 2: warnings.warn( "Warning: Gemini's vision model does not support chat history yet.", @@ -288,14 +275,9 @@ def create(self, params: Dict) -> ChatCompletion: completion_tokens = model.count_tokens(ans).total_tokens # 3. convert output - message = ChatCompletionMessage(role="assistant", - content=ans, - function_call=None, - tool_calls=None) + message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None) choices = [Choice(finish_reason="stop", index=0, message=message)] - context_cache_tokens = int( - self.context_cache.usage_metadata.total_token_count if self. 
- context_cache else 0) + context_cache_tokens = int(self.context_cache.usage_metadata.total_token_count if self.context_cache else 0) response_oai = ChatCompletion( id=str(random.randint(0, 1000)), @@ -308,15 +290,14 @@ def create(self, params: Dict) -> ChatCompletion: completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ), - cost=calculate_gemini_cost(prompt_tokens - context_cache_tokens, - completion_tokens, context_cache_tokens, - model_name), + cost=calculate_gemini_cost( + prompt_tokens - context_cache_tokens, completion_tokens, context_cache_tokens, model_name + ), ) return response_oai - def _oai_content_to_gemini_content(self, content: Union[str, - List]) -> List: + def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List: """Convert content from OAI format to Gemini format""" rst = [] if isinstance(content, str): @@ -344,16 +325,14 @@ def _oai_content_to_gemini_content(self, content: Union[str, re.match(r"data:image/(?:png|jpeg);base64,", img_url) img = get_image_data(img_url, use_b64=False) # image/png works with jpeg as well - img_part = VertexAIPart.from_data( - img, mime_type="image/png") + img_part = VertexAIPart.from_data(img, mime_type="image/png") rst.append(img_part) else: b64_img = get_image_data(msg["image_url"]["url"]) img = _to_pil(b64_img) rst.append(img) else: - raise ValueError( - f"Unsupported message type: {msg['type']}") + raise ValueError(f"Unsupported message type: {msg['type']}") else: raise ValueError(f"Unsupported message type: {type(msg)}") return rst @@ -371,8 +350,7 @@ def _concat_parts(self, parts: List[Part]) -> List: for current_part in parts[1:]: if previous_part.text != "": if self.use_vertexai: - previous_part = VertexAIPart.from_text(previous_part.text + - current_part.text) + previous_part = VertexAIPart.from_text(previous_part.text + current_part.text) else: previous_part.text += current_part.text else: @@ -388,8 +366,7 @@ def _concat_parts(self, parts: List[Part]) -> List: return concatenated_parts - def _oai_messages_to_gemini_messages( - self, messages: list[Dict[str, Any]]) -> list[dict[str, Any]]: + def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> list[dict[str, Any]]: """Convert messages from OAI format to Gemini format. Make sure the "user" role and "model" role are interleaved. Also, make sure the last item is from the "user" role. @@ -404,8 +381,7 @@ def _oai_messages_to_gemini_messages( curr_parts += parts elif role != prev_role: if self.use_vertexai: - rst.append( - VertexAIContent(parts=curr_parts, role=prev_role)) + rst.append(VertexAIContent(parts=curr_parts, role=prev_role)) else: rst.append(Content(parts=curr_parts, role=prev_role)) curr_parts = parts @@ -423,15 +399,9 @@ def _oai_messages_to_gemini_messages( # We add a dummy message "continue" if the last role is not the user. 
if rst[-1].role != "user": if self.use_vertexai: - rst.append( - VertexAIContent( - parts=self._oai_content_to_gemini_content("continue"), - role="user")) + rst.append(VertexAIContent(parts=self._oai_content_to_gemini_content("continue"), role="user")) else: - rst.append( - Content( - parts=self._oai_content_to_gemini_content("continue"), - role="user")) + rst.append(Content(parts=self._oai_content_to_gemini_content("continue"), role="user")) return rst @@ -440,24 +410,20 @@ def _to_vertexai_safety_settings(safety_settings): """Convert safety settings to VertexAI format if needed, like when specifying them in the OAI_CONFIG_LIST """ - if isinstance(safety_settings, list) and all([ - isinstance(safety_setting, dict) - and not isinstance(safety_setting, VertexAISafetySetting) + if isinstance(safety_settings, list) and all( + [ + isinstance(safety_setting, dict) and not isinstance(safety_setting, VertexAISafetySetting) for safety_setting in safety_settings - ]): + ] + ): vertexai_safety_settings = [] for safety_setting in safety_settings: - if safety_setting[ - "category"] not in VertexAIHarmCategory.__members__: + if safety_setting["category"] not in VertexAIHarmCategory.__members__: invalid_category = safety_setting["category"] - logger.error( - f"Safety setting category {invalid_category} is invalid" - ) - elif safety_setting[ - "threshold"] not in VertexAIHarmBlockThreshold.__members__: + logger.error(f"Safety setting category {invalid_category} is invalid") + elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__: invalid_threshold = safety_setting["threshold"] - logger.error( - f"Safety threshold {invalid_threshold} is invalid") + logger.error(f"Safety threshold {invalid_threshold} is invalid") else: vertexai_safety_setting = VertexAISafetySetting( category=safety_setting["category"], @@ -482,16 +448,20 @@ class GeminiContextCache: Context cache is available in Gemini 1.5. """ - def __init__(self, model: str, display_name: str, system_instruction: str, - contents: list[str], ttl: datetime.timedelta, use_vertexai=True): + def __init__( + self, + model: str, + display_name: str, + system_instruction: str, + contents: list[str], + ttl: datetime.timedelta, + use_vertexai=True, + ): self.use_vertexai = use_vertexai _caching = caching if use_vertexai else genai.caching self.cache = _caching.CachedContent.create( - model=model, - display_name=display_name, - system_instruction=system_instruction, - contents=contents, - ttl=ttl) + model=model, display_name=display_name, system_instruction=system_instruction, contents=contents, ttl=ttl + ) def is_compatible(self, model: Union[GenerativeModel, genai.GenerativeModel]) -> bool: """ @@ -499,13 +469,14 @@ def is_compatible(self, model: Union[GenerativeModel, genai.GenerativeModel]) -> """ # Context cache is available in gemini 1.5 stable versions. 
         if re.match(r"^gemini-1\.5-(pro|flash)-\d{3}$", model._model_name):
-            if ((self.use_vertexai and isinstance(model, GenerativeModel))
-                    or (not self.use_vertexai
-                        and isinstance(model, genai.GenerativeModel))):
+            if (self.use_vertexai and isinstance(model, GenerativeModel)) or (
+                not self.use_vertexai and isinstance(model, genai.GenerativeModel)
+            ):
                 return True
         warnings.warn(
             "Cache was created using a different SDK than the model: "
-            f"use_vertexai={self.use_vertexai}, type(model)={type(model)}")
+            f"use_vertexai={self.use_vertexai}, type(model)={type(model)}"
+        )
         return False
 
     def update_ttl(self, ttl: datetime.timedelta):
@@ -572,8 +543,7 @@ def get_image_data(image_file: str, use_b64=True) -> bytes:
     return content
 
 
-def calculate_gemini_cost(input_tokens: int, output_tokens: int,
-                          context_cache_tokens: int, model_name: str) -> float:
+def calculate_gemini_cost(input_tokens: int, output_tokens: int, context_cache_tokens: int, model_name: str) -> float:
     # TODO(yeounoh) - update the pricing model to reflect the prompt size
     if "1.5" in model_name or "gemini-experimental" in model_name:
         # "gemini-1.5-pro-preview-0409"
@@ -581,9 +551,7 @@ def calculate_gemini_cost(input_tokens: int, output_tokens: int,
         return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6 + 1.75 * context_cache_tokens / 1e6
 
     if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
-        warnings.warn(
-            f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.",
-            UserWarning)
+        warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning)
 
     # Cost is $0.5 per million input tokens and $1.5 per million output tokens
     return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6
diff --git a/test/oai/test_gemini.py b/test/oai/test_gemini.py
index 83eb528e759..3c362b4a599 100644
--- a/test/oai/test_gemini.py
+++ b/test/oai/test_gemini.py
@@ -1,9 +1,10 @@
 import os
 from unittest.mock import MagicMock, patch
 
-from autogen.oai.gemini import calculate_gemini_cost
 import pytest
 
+from autogen.oai.gemini import calculate_gemini_cost
+
 try:
     import google.auth
     from google.api_core.exceptions import InternalServerError
@@ -13,8 +14,7 @@ from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
     from vertexai.generative_models import SafetySetting as VertexAISafetySetting
 
-    from autogen.oai.gemini import GeminiClient
-    from autogen.oai.gemini import GeminiContextCache
+    from autogen.oai.gemini import GeminiClient, GeminiContextCache
 
     skip = False
 except ImportError:
@@ -282,9 +282,7 @@ def test_cost_calculation(gemini_client, mock_response):
 
     response_with_cache = mock_response(
         text="Example response",
-        choices=[{
-            "message": "Test message 1"
-        }],
+        choices=[{"message": "Test message 1"}],
         usage={
             # openai usage stats do not reflect gemini context caching.
             "prompt_tokens": 10,
@@ -292,13 +290,14 @@ def test_cost_calculation(gemini_client, mock_response):
             "total_tokens": 15,
             # context_cache_tokens should offset prompt_tokens and reduce the
             # total cost durign the cost calculation.
-            "context_cache_tokens": 3
+            "context_cache_tokens": 3,
         },
         cost=0.00015925,
         model="gemini-pro",
     )
-    assert gemini_client.cost(response) > gemini_client.cost(response_with_cache), \
-        "Context caching should reduce the cost."
+    assert gemini_client.cost(response) > gemini_client.cost(
+        response_with_cache
+    ), "Context caching should reduce the cost."
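The mock costs in the test above follow directly from calculate_gemini_cost as patched in this series: cached prefix tokens are subtracted from the billed prompt tokens and charged at the $1.75-per-million rate hard-coded here instead of the $7-per-million input rate. A minimal sketch of that arithmetic, using the same token counts as the test (the model name is illustrative; any Gemini 1.5 model name triggers the same branch):

    # Illustrative only: 10 prompt tokens, 5 completion tokens, 3 prompt tokens served from
    # the context cache. `input_tokens` is the non-cached part of the prompt, as in create().
    from autogen.oai.gemini import calculate_gemini_cost

    without_cache = calculate_gemini_cost(10, 5, 0, "gemini-1.5-pro-001")
    # 7.0 * 10 / 1e6 + 21.0 * 5 / 1e6 = 0.000175
    with_cache = calculate_gemini_cost(10 - 3, 5, 3, "gemini-1.5-pro-001")
    # 7.0 * 7 / 1e6 + 21.0 * 5 / 1e6 + 1.75 * 3 / 1e6 = 0.00015925
    assert with_cache < without_cache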
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") @@ -384,10 +383,13 @@ def test_vertexai_default_auth_create_response(mock_init, mock_generative_model, # Assertions to check if response is structured as expected assert response.choices[0].message.content == "Example response", "Response content should match expected output" + @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") @patch("autogen.oai.gemini.GenerativeModel") @patch("autogen.oai.gemini.vertexai.init") -def test_vertexai_default_auth_create_response_with_context_cache(mock_init, mock_generative_model, gemini_google_auth_default_client): +def test_vertexai_default_auth_create_response_with_context_cache( + mock_init, mock_generative_model, gemini_google_auth_default_client +): # Mock the genai model configuration and creation process mock_chat = MagicMock() mock_model = MagicMock() @@ -404,32 +406,28 @@ def test_vertexai_default_auth_create_response_with_context_cache(mock_init, moc mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])]) # Setup the mock to return a mocked cache usage - mock_context_cache = MagicMock(usage_metadata=MagicMock(total_token_count = 10)) + mock_context_cache = MagicMock(usage_metadata=MagicMock(total_token_count=10)) # Call the create method response = gemini_google_auth_default_client.create( {"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False} ) - response_with_cache = gemini_google_auth_default_client.create({ - "model": - "gemini-1.5-pro-001", - "context_cache": - mock_context_cache, - "messages": [{ - "content": "Hello", - "role": "user" - }], - "stream": - False - }) + response_with_cache = gemini_google_auth_default_client.create( + { + "model": "gemini-1.5-pro-001", + "context_cache": mock_context_cache, + "messages": [{"content": "Hello", "role": "user"}], + "stream": False, + } + ) # Assertions to check if response is structured as expected - assert response_with_cache.choices[ - 0].message.content == "Example response", "Response content should match expected output" - assert gemini_google_auth_default_client.cost( - response) > gemini_google_auth_default_client.cost( - response_with_cache - ), "Context caching should result in reduced cost." + assert ( + response_with_cache.choices[0].message.content == "Example response" + ), "Response content should match expected output" + assert gemini_google_auth_default_client.cost(response) > gemini_google_auth_default_client.cost( + response_with_cache + ), "Context caching should result in reduced cost." @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
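Taken together, these patches add an opt-in prefix cache to the Gemini client. A minimal end-to-end sketch of the intended usage, based only on the signatures introduced in this series (the model version, display name, cached contents, and Vertex AI project settings below are illustrative assumptions, not values taken from the patches):

    import datetime

    from autogen.oai.gemini import GeminiClient, GeminiContextCache

    # Cache a long shared prefix once; context caching requires a Gemini 1.5 stable model.
    context_cache = GeminiContextCache(
        model="gemini-1.5-pro-001",                       # illustrative model version
        display_name="shared-reference-doc",              # illustrative display name
        system_instruction="You answer questions about the attached document.",
        contents=["<a long document reused by every request>"],  # illustrative cached prefix
        ttl=datetime.timedelta(hours=1),
        use_vertexai=True,
    )

    # Reuse the cached prefix across calls; cached tokens are subtracted from the billed prompt.
    client = GeminiClient(project_id="my-project", location="us-west1")  # illustrative Vertex AI settings
    response = client.create(
        {
            "model": "gemini-1.5-pro-001",
            "context_cache": context_cache,
            "messages": [{"role": "user", "content": "Summarize section 2 of the document."}],
            "stream": False,
        }
    )

    context_cache.update_ttl(datetime.timedelta(hours=2))  # extend the cache, or
    context_cache.delete()                                 # drop it when no longer needed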