diff --git a/cookbook/providers/vertexai/README.md b/cookbook/providers/vertexai/README.md
index 7bd98eef1..06674b896 100644
--- a/cookbook/providers/vertexai/README.md
+++ b/cookbook/providers/vertexai/README.md
@@ -1,4 +1,4 @@
-# Google Gemini Cookbook
+# VertexAI Gemini Cookbook
 
 > Note: Fork and clone this repository if needed
 
@@ -9,16 +9,14 @@
 python3 -m venv ~/.venvs/aienv
 source ~/.venvs/aienv/bin/activate
 ```
 
-### 2. Export `GOOGLE_API_KEY`
+### 2. Authenticate with Google Cloud
 
-```shell
-export GOOGLE_API_KEY=***
-```
+[Authenticate with gcloud](https://cloud.google.com/vertex-ai/generative-ai/docs/start/quickstarts/quickstart-multimodal)
 
 ### 3. Install libraries
 
 ```shell
-pip install -U google-generativeai duckduckgo-search yfinance phidata
+pip install -U google-cloud-aiplatform duckduckgo-search yfinance phidata
 ```
 
 ### 4. Run Agent without Tools
diff --git a/cookbook/providers/vertexai/agent_stream.py b/cookbook/providers/vertexai/agent_stream.py
index 041777b31..2bf333b58 100644
--- a/cookbook/providers/vertexai/agent_stream.py
+++ b/cookbook/providers/vertexai/agent_stream.py
@@ -2,7 +2,7 @@
 from typing import Iterator  # noqa
 from phi.agent import Agent, RunResponse  # noqa
-from phi.model.google import Gemini
+from phi.model.vertexai import Gemini
 from phi.tools.yfinance import YFinanceTools
 
 agent = Agent(
diff --git a/cookbook/providers/vertexai/basic_stream.py b/cookbook/providers/vertexai/basic_stream.py
index 44ca0f5e4..dcff6f78b 100644
--- a/cookbook/providers/vertexai/basic_stream.py
+++ b/cookbook/providers/vertexai/basic_stream.py
@@ -1,6 +1,6 @@
 from typing import Iterator  # noqa
 from phi.agent import Agent, RunResponse  # noqa
-from phi.model.google import Gemini
+from phi.model.vertexai import Gemini
 
 agent = Agent(model=Gemini(id="gemini-1.5-flash"), markdown=True)
diff --git a/cookbook/providers/vertexai/data_analyst.py b/cookbook/providers/vertexai/data_analyst.py
index d5802d7b7..e03dddd4b 100644
--- a/cookbook/providers/vertexai/data_analyst.py
+++ b/cookbook/providers/vertexai/data_analyst.py
@@ -2,7 +2,7 @@
 from textwrap import dedent
 from phi.agent import Agent
-from phi.model.google import Gemini
+from phi.model.vertexai import Gemini
 from phi.tools.duckdb import DuckDbTools
 
 duckdb_tools = DuckDbTools(create_tables=False, export_tables=False, summarize_tables=False)
diff --git a/cookbook/providers/vertexai/finance_agent.py b/cookbook/providers/vertexai/finance_agent.py
index 4876f12ac..b0f4ea90b 100644
--- a/cookbook/providers/vertexai/finance_agent.py
+++ b/cookbook/providers/vertexai/finance_agent.py
@@ -1,7 +1,7 @@
 """Run `pip install yfinance` to install dependencies."""
 
 from phi.agent import Agent
-from phi.model.google import Gemini
+from phi.model.vertexai import Gemini
 from phi.tools.yfinance import YFinanceTools
 
 agent = Agent(
diff --git a/cookbook/providers/vertexai/knowledge.py b/cookbook/providers/vertexai/knowledge.py
index bdf7299c7..2574311e1 100644
--- a/cookbook/providers/vertexai/knowledge.py
+++ b/cookbook/providers/vertexai/knowledge.py
@@ -1,7 +1,7 @@
 """Run `pip install duckduckgo-search sqlalchemy pgvector pypdf openai google.generativeai` to install dependencies."""
 
 from phi.agent import Agent
-from phi.model.google import Gemini
+from phi.model.vertexai import Gemini
 from phi.knowledge.pdf import PDFUrlKnowledgeBase
 from phi.vectordb.pgvector import PgVector
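Every cookbook script switches its import from `phi.model.google` to `phi.model.vertexai`. A minimal sketch of the updated pattern (assumes a Google Cloud project with the Vertex AI API enabled and application-default credentials already configured):

```python
# Minimal sketch of the updated cookbook pattern; assumes
# `gcloud auth application-default login` has already been run.
from phi.agent import Agent
from phi.model.vertexai import Gemini  # was: from phi.model.google import Gemini

agent = Agent(model=Gemini(id="gemini-1.5-flash"), markdown=True)
agent.print_response("Share a 2 sentence horror story.")
```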
diff --git a/cookbook/providers/vertexai/storage.py b/cookbook/providers/vertexai/storage.py
index 8fae33515..c8f2c8c21 100644
--- a/cookbook/providers/vertexai/storage.py
+++ b/cookbook/providers/vertexai/storage.py
@@ -1,7 +1,7 @@
 """Run `pip install duckduckgo-search sqlalchemy google.generativeai` to install dependencies."""
 
 from phi.agent import Agent
-from phi.model.google import Gemini
+from phi.model.vertexai import Gemini
 from phi.tools.duckduckgo import DuckDuckGo
 from phi.storage.agent.postgres import PgAgentStorage
diff --git a/cookbook/providers/vertexai/structured_output.py b/cookbook/providers/vertexai/structured_output.py
index 377f45544..de8d3abc2 100644
--- a/cookbook/providers/vertexai/structured_output.py
+++ b/cookbook/providers/vertexai/structured_output.py
@@ -2,7 +2,7 @@
 from rich.pretty import pprint  # noqa
 from pydantic import BaseModel, Field
 from phi.agent import Agent, RunResponse  # noqa
-from phi.model.google import Gemini
+from phi.model.vertexai import Gemini
 
 
 class MovieScript(BaseModel):
diff --git a/cookbook/providers/vertexai/web_search.py b/cookbook/providers/vertexai/web_search.py
index e3746c5aa..267fe1910 100644
--- a/cookbook/providers/vertexai/web_search.py
+++ b/cookbook/providers/vertexai/web_search.py
@@ -1,7 +1,7 @@
 """Run `pip install duckduckgo-search` to install dependencies."""
 
 from phi.agent import Agent
-from phi.model.google import Gemini
+from phi.model.vertexai import Gemini
 from phi.tools.duckduckgo import DuckDuckGo
 
 agent = Agent(model=Gemini(id="gemini-1.5-flash"), tools=[DuckDuckGo()], show_tool_calls=True, markdown=True)
diff --git a/phi/model/vertexai/gemini.py b/phi/model/vertexai/gemini.py
index c80ae8b0f..4f3702d7a 100644
--- a/phi/model/vertexai/gemini.py
+++ b/phi/model/vertexai/gemini.py
@@ -17,76 +17,246 @@
         GenerationResponse,
         FunctionDeclaration,
         Tool as GeminiTool,
-        Candidate as GenerationResponseCandidate,
-        Content as GenerationResponseContent,
-        Part as GenerationResponsePart,
+        Candidate,
+        Content,
+        Part,
     )
+
+    from google.cloud.aiplatform_v1beta1.types.prediction_service import GenerateContentResponse
+
+    UsageMetadata = GenerateContentResponse.UsageMetadata
+
 except ImportError:
     logger.error("`google-cloud-aiplatform` not installed")
     raise
 
+
 @dataclass
 class MessageData:
     response_content: str = ""
-    response_block = None
+    response_block: Optional[Content] = None
+    response_candidates: Optional[List[Candidate]] = None
     response_role: Optional[str] = None
     response_parts: Optional[List] = None
     response_tool_calls: List[Dict[str, Any]] = field(default_factory=list)
-    response_usage = None
+    response_usage: Optional[UsageMetadata] = None
+    response_tool_call_block: Optional[Content] = None
 
 
 @dataclass
 class Metrics:
-    input_tokens: Optional[int] = None
-    output_tokens: Optional[int] = None
-    total_tokens: Optional[int] = None
+    input_tokens: int = 0
+    output_tokens: int = 0
+    total_tokens: int = 0
     time_to_first_token: Optional[float] = None
     response_timer: Timer = field(default_factory=Timer)
 
+    def log(self):
+        logger.debug("**************** METRICS START ****************")
+        if self.time_to_first_token is not None:
+            logger.debug(f"* Time to first token: {self.time_to_first_token:.4f}s")
+        logger.debug(f"* Time to generate response: {self.response_timer.elapsed:.4f}s")
+        logger.debug(f"* Tokens per second: {self.output_tokens / self.response_timer.elapsed:.4f} tokens/s")
+        logger.debug(f"* Input tokens: {self.input_tokens}")
+        logger.debug(f"* Output tokens: {self.output_tokens}")
+        logger.debug(f"* Total tokens: {self.total_tokens}")
+        logger.debug("**************** METRICS END ******************")
+
 
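For orientation, a hedged sketch of how these two dataclasses are driven during a request; the field names are the ones defined above, the token counts are hypothetical:

```python
# Illustrative only: how MessageData/Metrics are typically exercised.
metrics = Metrics()
metrics.response_timer.start()
# ... first chunk arrives ...
metrics.time_to_first_token = metrics.response_timer.elapsed
# ... generation finishes ...
metrics.response_timer.stop()
metrics.input_tokens, metrics.output_tokens = 42, 128  # hypothetical counts
metrics.total_tokens = metrics.input_tokens + metrics.output_tokens
metrics.log()  # emits the METRICS START/END block defined above
```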
 class Gemini(Model):
     name: str = "Gemini"
     model: str = "gemini-1.5-flash-002"
     provider: str = "VertexAI"
 
+    # Request parameters
     generation_config: Optional[Any] = None
     safety_settings: Optional[Any] = None
-    function_declarations: Optional[List[FunctionDeclaration]] = None
     generative_model_request_params: Optional[Dict[str, Any]] = None
-    generative_model: Optional[GenerativeModel] = None
+    function_declarations: Optional[List[FunctionDeclaration]] = None
+
+    # Gemini client
+    client: Optional[GenerativeModel] = None
 
     def get_client(self) -> GenerativeModel:
-        if self.generative_model is None:
-            self.generative_model = GenerativeModel(model_name=self.model, **self.request_kwargs)
-        return self.generative_model
+        """
+        Returns a GenerativeModel client.
+
+        Returns:
+            GenerativeModel: GenerativeModel client.
+        """
+        if self.client is None:
+            self.client = GenerativeModel(model_name=self.model, **self.request_kwargs)
+        return self.client
 
     @property
     def request_kwargs(self) -> Dict[str, Any]:
+        """
+        Returns the request parameters for the generative model.
+
+        Returns:
+            Dict[str, Any]: Request parameters for the generative model.
+        """
         _request_params: Dict[str, Any] = {}
         if self.generation_config:
             _request_params["generation_config"] = self.generation_config
         if self.safety_settings:
             _request_params["safety_settings"] = self.safety_settings
         if self.generative_model_request_params:
-            _request_params.update(self.generative_model__request_params)
+            _request_params.update(self.generative_model_request_params)
         if self.function_declarations:
             _request_params["tools"] = [GeminiTool(function_declarations=self.function_declarations)]
         return _request_params
 
-    def convert_messages_to_contents(self, messages: List[Message]) -> List[Any]:
-        _contents: List[Any] = []
-        for m in messages:
-            if isinstance(m.content, str):
-                _contents.append(m.content)
-            elif isinstance(m.content, list):
-                _contents.extend(m.content)
-        return _contents
+    def _format_messages(self, messages: List[Message]) -> List[Content]:
+        """
+        Converts a list of Message objects to Gemini-compatible Content objects.
+
+        Args:
+            messages: List of Message objects containing various types of content.
+
+        Returns:
+            List of Content objects formatted for Gemini's API.
+        """
+        formatted_messages: List[Content] = []
+
+        for msg in messages:
+            # Use `getattr` with a default: these attributes are reset to None
+            # after a run, so `hasattr` alone would still be truthy.
+            if getattr(msg, "response_tool_call_block", None) is not None:
+                formatted_messages.append(Content(role=msg.role, parts=msg.response_tool_call_block.parts))
+                continue
+            if msg.role == "tool" and getattr(msg, "tool_call_result", None) is not None:
+                formatted_messages.append(msg.tool_call_result)
+                continue
+            if isinstance(msg.content, str):
+                parts = [Part.from_text(msg.content)]
+            elif isinstance(msg.content, list):
+                parts = [Part.from_text(part) for part in msg.content if isinstance(part, str)]
+            else:
+                parts = []
+            role = "model" if msg.role == "system" else "user" if msg.role == "tool" else msg.role
+
+            formatted_messages.append(Content(role=role, parts=parts))
+
+        return formatted_messages
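`_format_messages` has to map phidata roles onto the two roles Gemini's `Content` accepts; the ternary above spelled out as a quick illustration:

```python
# The role translation performed by _format_messages (illustrative only):
for phi_role in ("system", "tool", "user", "model"):
    gemini_role = "model" if phi_role == "system" else "user" if phi_role == "tool" else phi_role
    print(f"{phi_role!r} -> {gemini_role!r}")
# 'system' -> 'model', 'tool' -> 'user', 'user' -> 'user', 'model' -> 'model'
```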
+ """ + formatted_params = {} + for key, value in params.items(): + if key == "properties" and isinstance(value, dict): + converted_properties = {} + for prop_key, prop_value in value.items(): + property_type = prop_value.get("type") + if isinstance(property_type, list): + # Create a copy to avoid modifying the original list + non_null_types = [t for t in property_type if t != "null"] + if non_null_types: + # Use the first non-null type + converted_type = non_null_types[0] + else: + # Default type if all types are 'null' + converted_type = "string" + else: + converted_type = property_type + + converted_properties[prop_key] = {"type": converted_type} + formatted_params[key] = converted_properties + else: + formatted_params[key] = value + return formatted_params + + def add_tool( + self, tool: Union["Tool", "Toolkit", Callable, dict, "Function"], structured_outputs: bool = False + ) -> None: + """ + Adds tools to the model. + + Args: + tool: The tool to add. Can be a Tool, Toolkit, Callable, dict, or Function. + """ + if self.function_declarations is None: + self.function_declarations = [] + + # If the tool is a Tool or Dict, log a warning. + if isinstance(tool, Tool) or isinstance(tool, Dict): + logger.warning("Tool of type 'Tool' or 'dict' is not yet supported by Gemini.") + + # If the tool is a Callable or Toolkit, add its functions to the Model + elif callable(tool) or isinstance(tool, Toolkit) or isinstance(tool, Function): + if self.functions is None: + self.functions = {} + + if isinstance(tool, Toolkit): + # For each function in the toolkit + for name, func in tool.functions.items(): + # If the function does not exist in self.functions, add to self.tools + if name not in self.functions: + self.functions[name] = func + function_declaration = FunctionDeclaration( + name=func.name, + description=func.description, + parameters=self._format_functions(func.parameters), + ) + self.function_declarations.append(function_declaration) + logger.debug(f"Function {name} from {tool.name} added to model.") + + elif isinstance(tool, Function): + if tool.name not in self.functions: + self.functions[tool.name] = tool + function_declaration = FunctionDeclaration( + name=tool.name, + description=tool.description, + parameters=self._format_functions(tool.parameters), + ) + self.function_declarations.append(function_declaration) + logger.debug(f"Function {tool.name} added to model.") + + elif callable(tool): + try: + function_name = tool.__name__ + if function_name not in self.functions: + func = Function.from_callable(tool) + self.functions[func.name] = func + function_declaration = FunctionDeclaration( + name=func.name, + description=func.description, + parameters=self._format_functions(func.parameters), + ) + self.function_declarations.append(function_declaration) + logger.debug(f"Function '{func.name}' added to model.") + except Exception as e: + logger.warning(f"Could not add function {tool}: {e}") def invoke(self, messages: List[Message]) -> GenerationResponse: - return self.get_client().generate_content(contents=self.convert_messages_to_contents(messages)) + """ + Send a generate content request to VertexAI and return the response. 
 
     def invoke(self, messages: List[Message]) -> GenerationResponse:
-        return self.get_client().generate_content(contents=self.convert_messages_to_contents(messages))
+        """
+        Send a generate content request to VertexAI and return the response.
+
+        Args:
+            messages: List of Message objects containing various types of content
+
+        Returns:
+            GenerationResponse object containing the response content
+        """
+        return self.get_client().generate_content(contents=self._format_messages(messages))
 
     def invoke_stream(self, messages: List[Message]) -> Iterator[GenerationResponse]:
-        yield from self.client.generate_content(
-            contents=self.convert_messages_to_contents(messages),
+        """
+        Send a generate content request to VertexAI and stream the response.
+
+        Args:
+            messages: List of Message objects containing various types of content
+
+        Yields:
+            Iterator[GenerationResponse] object containing the response content
+        """
+        yield from self.get_client().generate_content(
+            contents=self._format_messages(messages),
             stream=True,
         )
 
@@ -97,21 +267,19 @@ def _log_messages(self, messages: List[Message]) -> None:
         for m in messages:
             m.log()
 
-
     def _update_usage_metrics(
         self,
         assistant_message: Message,
-        usage=None,
-        stream_usage=None,
-        metrics=Metrics(),
+        metrics: Metrics,
+        usage: Optional[UsageMetadata] = None,
     ) -> None:
         """
-        Update the usage metrics.
+        Update usage metrics for the assistant message.
 
         Args:
-            assistant_message (Message): The assistant message.
-            usage (ResultGenerateContentResponse): The usage metrics.
-            stream_usage (Optional[StreamUsageData]): The stream usage metrics.
+            assistant_message: Message object containing the response content
+            metrics: Metrics object containing the usage metrics
+            usage: UsageMetadata object containing the usage metrics
         """
         assistant_message.metrics["time"] = metrics.response_timer.elapsed
         self.metrics.setdefault("response_times", []).append(metrics.response_timer.elapsed)
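The `setdefault`-and-append pattern above accumulates per-request metrics onto the model across runs; the same idiom in isolation, with hypothetical timings:

```python
model_metrics: dict = {}
for elapsed in (0.81, 1.24):  # two hypothetical requests
    model_metrics.setdefault("response_times", []).append(elapsed)
print(model_metrics)  # {'response_times': [0.81, 1.24]}
```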
 
@@ -133,16 +301,16 @@ def _update_usage_metrics(
             assistant_message.metrics["time_to_first_token"] = metrics.time_to_first_token
             self.metrics.setdefault("time_to_first_token", []).append(metrics.time_to_first_token)
 
-    def _create_assistant_message(self, response: GenerationResponse, response_timer: Timer) -> Message:
+    def _create_assistant_message(self, response: GenerationResponse, metrics: Metrics) -> Message:
         """
-        Create an assistant message from the model response.
+        Create an assistant message from the GenerationResponse.
 
         Args:
-            response (GenerateContentResponse): The model response.
-            response_timer (Timer): The response timer.
+            response: GenerationResponse object containing the response content
+            metrics: Metrics object containing the usage metrics
 
         Returns:
-            Message: The assistant message.
+            Message object containing the assistant message
         """
         message_data = MessageData()
 
@@ -152,6 +320,7 @@ def _create_assistant_message(self, response: GenerationResponse, response_timer
         message_data.response_parts = message_data.response_block.parts
         message_data.response_usage = response.usage_metadata
 
+        # -*- Parse response
         if message_data.response_parts is not None:
             for part in message_data.response_parts:
                 part_dict = type(part).to_dict(part)
@@ -162,6 +331,7 @@ def _create_assistant_message(self, response: GenerationResponse, response_timer
 
                 # Parse function calls
                 if "function_call" in part_dict:
+                    message_data.response_tool_call_block = response.candidates[0].content
                     message_data.response_tool_calls.append(
                         {
                             "type": "function",
@@ -172,25 +342,124 @@ def _create_assistant_message(self, response: GenerationResponse, response_timer
                         }
                     )
 
-        # Create assistant message
+        # -*- Create assistant message
         assistant_message = Message(
             role=message_data.response_role or "model",
             content=message_data.response_content,
-            parts=message_data.response_parts,
+            response_tool_call_block=message_data.response_tool_call_block,
         )
 
-        # Update assistant message if tool calls are present
+        # -*- Update assistant message if tool calls are present
         if len(message_data.response_tool_calls) > 0:
             assistant_message.tool_calls = message_data.response_tool_calls
 
-        # Update usage metrics
-        assistant_message.metrics["time"] = response_timer.elapsed
-        self.metrics.setdefault("response_times", []).append(response_timer.elapsed)
-        self._update_usage_metrics(assistant_message, message_data.response_usage)
+        # -*- Update usage metrics
+        self._update_usage_metrics(
+            assistant_message=assistant_message, metrics=metrics, usage=message_data.response_usage
+        )
 
         return assistant_message
 
+    def _get_function_calls_to_run(
+        self,
+        assistant_message: Message,
+        messages: List[Message],
+    ) -> List[FunctionCall]:
+        """
+        Extracts and validates function calls from the assistant message.
+
+        Args:
+            assistant_message (Message): The assistant message containing tool calls.
+            messages (List[Message]): The list of conversation messages.
+
+        Returns:
+            List[FunctionCall]: A list of valid function calls to run.
+        """
+        function_calls_to_run: List[FunctionCall] = []
+        if assistant_message.tool_calls:
+            for tool_call in assistant_message.tool_calls:
+                _function_call = get_function_call_for_tool_call(tool_call, self.functions)
+                if _function_call is None:
+                    messages.append(Message(role="tool", content="Could not find function to call."))
+                    continue
+                if _function_call.error is not None:
+                    messages.append(Message(role="tool", content=_function_call.error))
+                    continue
+                function_calls_to_run.append(_function_call)
+        return function_calls_to_run
+
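The entries appended to `response_tool_calls` use the OpenAI-style shape that `get_function_call_for_tool_call` expects; sketched with a hypothetical tool name:

```python
import json

tool_call = {
    "type": "function",
    "function": {
        "name": "get_current_stock_price",            # hypothetical tool name
        "arguments": json.dumps({"symbol": "NVDA"}),  # Gemini args re-serialized as JSON
    },
}
```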
+ """ + if function_call_results: + contents, parts = zip( + *[ + ( + result.content, + Part.from_function_response(name=result.tool_name, response={"content": result.content}), + ) + for result in function_call_results + ] + ) + + messages.append(Message(role="tool", content=list(contents), tool_call_result=Content(parts=list(parts)))) + + def _handle_tool_calls(self, assistant_message: Message, messages: List[Message], model_response: ModelResponse): + """ + Handle tool calls in the assistant message. + + Args: + assistant_message (Message): The assistant message. + messages (List[Message]): A list of messages. + model_response (ModelResponse): The model response. + + Returns: + Optional[ModelResponse]: The updated model response. + """ + if assistant_message.tool_calls and self.run_tools: + model_response.content = assistant_message.get_content_string() or "" + function_calls_to_run = self._get_function_calls_to_run(assistant_message, messages) + + if self.show_tool_calls: + if len(function_calls_to_run) == 1: + model_response.content += f"\n - Running: {function_calls_to_run[0].get_call_str()}\n\n" + elif len(function_calls_to_run) > 1: + model_response.content += "\nRunning:" + for _f in function_calls_to_run: + model_response.content += f"\n - {_f.get_call_str()}" + model_response.content += "\n\n" + + function_call_results: List[Message] = [] + for _ in self.run_function_calls( + function_calls=function_calls_to_run, + function_call_results=function_call_results, + ): + pass + + self._format_function_call_results(function_call_results, messages) + + return model_response + return None + def response(self, messages: List[Message]) -> ModelResponse: + """ + Send a generate content request to VertexAI and return the response. + + Args: + messages: List of Message objects containing various types of content + + Returns: + ModelResponse object containing the response content + """ logger.debug("---------- VertexAI Response Start ----------") self._log_messages(messages) model_response = ModelResponse() @@ -200,180 +469,151 @@ def response(self, messages: List[Message]) -> ModelResponse: response: GenerationResponse = self.invoke(messages=messages) metrics.response_timer.stop() - # response_candidates: List[GenerationResponseCandidate] = response.candidates - # response_content: GenerationResponseContent = response_candidates[0].content - # response_role = response_content.role - # response_parts: List[GenerationResponsePart] = response_content.parts - # response_text: Optional[str] = None - # response_function_calls: Optional[List[Dict[str, Any]]] = None - # - # if len(response_parts) > 1: - # logger.warning("Multiple content parts are not yet supported.") - # return "More than one response part found." 
     def response(self, messages: List[Message]) -> ModelResponse:
+        """
+        Send a generate content request to VertexAI and return the response.
+
+        Args:
+            messages: List of Message objects containing various types of content
+
+        Returns:
+            ModelResponse object containing the response content
+        """
         logger.debug("---------- VertexAI Response Start ----------")
         self._log_messages(messages)
         model_response = ModelResponse()
@@ -200,180 +469,151 @@ def response(self, messages: List[Message]) -> ModelResponse:
         response: GenerationResponse = self.invoke(messages=messages)
         metrics.response_timer.stop()
 
-        # response_candidates: List[GenerationResponseCandidate] = response.candidates
-        # response_content: GenerationResponseContent = response_candidates[0].content
-        # response_role = response_content.role
-        # response_parts: List[GenerationResponsePart] = response_content.parts
-        # response_text: Optional[str] = None
-        # response_function_calls: Optional[List[Dict[str, Any]]] = None
-        #
-        # if len(response_parts) > 1:
-        #     logger.warning("Multiple content parts are not yet supported.")
-        #     return "More than one response part found."
-        #
-        # _part_dict = response_parts[0].to_dict()
-        # if "text" in _part_dict:
-        #     response_text = _part_dict.get("text")
-        # if "function_call" in _part_dict:
-        #     if response_function_calls is None:
-        #         response_function_calls = []
-        #     response_function_calls.append(
-        #         {
-        #             "type": "function",
-        #             "function": {
-        #                 "name": _part_dict.get("function_call").get("name"),
-        #                 "arguments": json.dumps(_part_dict.get("function_call").get("args")),
-        #             },
-        #         }
-        #     )
-
         # -*- Create assistant message
-        assistant_message = self._create_assistant_message(response=response, response_timer=metrics.response_timer)
+        assistant_message = self._create_assistant_message(response=response, metrics=metrics)
         messages.append(assistant_message)
+
+        # -*- Log response and metrics
         assistant_message.log()
+        metrics.log()
 
-        # # -*- Add tool calls to assistant message
-        # if response_function_calls is not None:
-        #     assistant_message.tool_calls = response_function_calls
-        #
-        # # -*- Update usage metrics
-        # # Add response time to metrics
-        # assistant_message.metrics["time"] = response_timer.elapsed
-        # if "response_times" not in self.metrics:
-        #     self.metrics["response_times"] = []
-        # self.metrics["response_times"].append(response_timer.elapsed)
-        # # TODO: Add token usage to metrics
-        #
-        # # -*- Add assistant message to messages
-        # messages.append(assistant_message)
-        # assistant_message.log()
-        #
-        # # -*- Parse and run function calls
-        # if assistant_message.tool_calls is not None:
-        #     final_response = ""
-        #     function_calls_to_run: List[FunctionCall] = []
-        #     for tool_call in assistant_message.tool_calls:
-        #         _tool_call_id = tool_call.get("id")
-        #         _function_call = get_function_call_for_tool_call(tool_call, self.functions)
-        #         if _function_call is None:
-        #             messages.append(
-        #                 Message(role="tool", tool_call_id=_tool_call_id, content="Could not find function to call.")
-        #             )
-        #             continue
-        #         if _function_call.error is not None:
-        #             messages.append(Message(role="tool", tool_call_id=_tool_call_id, content=_function_call.error))
-        #             continue
-        #         function_calls_to_run.append(_function_call)
-        #
-        #     if self.show_tool_calls:
-        #         if len(function_calls_to_run) == 1:
-        #             final_response += f"\n - Running: {function_calls_to_run[0].get_call_str()}\n\n"
-        #         elif len(function_calls_to_run) > 1:
-        #             final_response += "\nRunning:"
-        #             for _f in function_calls_to_run:
-        #                 final_response += f"\n - {_f.get_call_str()}"
-        #             final_response += "\n\n"
-        #
-        #     function_call_results = self.run_function_calls(function_calls_to_run)
-        #     if len(function_call_results) > 0:
-        #         messages.extend(function_call_results)
-        #     # -*- Get new response using result of tool call
-        #     final_response += self.response(messages=messages)
-        #     return final_response
+        if self._handle_tool_calls(assistant_message, messages, model_response):
+            response_after_tool_calls = self.response(messages=messages)
+            if response_after_tool_calls.content is not None:
+                if model_response.content is None:
+                    model_response.content = ""
+                model_response.content += response_after_tool_calls.content
+            return model_response
 
         if assistant_message.content is not None:
             model_response.content = assistant_message.get_content_string()
 
+        # -*- Remove tool call blocks and tool call results from messages
+        for m in messages:
+            if hasattr(m, "response_tool_call_block"):
+                m.response_tool_call_block = None
+            if hasattr(m, "tool_call_result"):
+                m.tool_call_result = None
+
         logger.debug("---------- VertexAI Response End ----------")
         return model_response
 
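Seen from the caller's side, `response()` now resolves tool calls internally (run the tools, append their results, recurse for the follow-up answer); a hedged end-to-end sketch using the cookbook's finance agent:

```python
from phi.agent import Agent
from phi.model.vertexai import Gemini
from phi.tools.yfinance import YFinanceTools

agent = Agent(
    model=Gemini(id="gemini-1.5-flash"),
    tools=[YFinanceTools(stock_price=True)],
    show_tool_calls=True,
    markdown=True,
)
# One call triggers: invoke -> tool call -> function run -> recursive follow-up.
agent.print_response("What is the current stock price of NVDA?")
```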
-    def response_stream(self, messages: List[Message]) -> Iterator[str]:
+    def _handle_stream_tool_calls(self, assistant_message: Message, messages: List[Message]):
+        """
+        Parse and run function calls and append the results to messages.
+
+        Args:
+            assistant_message (Message): The assistant message containing tool calls.
+            messages (List[Message]): The list of conversation messages.
+
+        Yields:
+            Iterator[ModelResponse]: Yields model responses during function execution.
+        """
+        if assistant_message.tool_calls and self.run_tools:
+            function_calls_to_run = self._get_function_calls_to_run(assistant_message, messages)
+
+            if self.show_tool_calls:
+                if len(function_calls_to_run) == 1:
+                    yield ModelResponse(content=f"\n - Running: {function_calls_to_run[0].get_call_str()}\n\n")
+                elif len(function_calls_to_run) > 1:
+                    yield ModelResponse(content="\nRunning:")
+                    for _f in function_calls_to_run:
+                        yield ModelResponse(content=f"\n - {_f.get_call_str()}")
+                    yield ModelResponse(content="\n\n")
+
+            function_call_results: List[Message] = []
+            for intermediate_model_response in self.run_function_calls(
+                function_calls=function_calls_to_run, function_call_results=function_call_results
+            ):
+                yield intermediate_model_response
+
+            self._format_function_call_results(function_call_results, messages)
+
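`response_stream` now yields `ModelResponse` chunks instead of raw strings, which `Agent.run(..., stream=True)` surfaces as `RunResponse` chunks; a hedged consumption sketch matching the cookbook's `agent_stream.py`:

```python
from typing import Iterator
from phi.agent import Agent, RunResponse
from phi.model.vertexai import Gemini

agent = Agent(model=Gemini(id="gemini-1.5-flash"), markdown=True)
run_stream: Iterator[RunResponse] = agent.run("Tell me a joke.", stream=True)
for chunk in run_stream:
    print(chunk.content, end="", flush=True)
```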
+    def response_stream(self, messages: List[Message]) -> Iterator[ModelResponse]:
+        """
+        Send a generate content request to VertexAI and stream the response.
+
+        Args:
+            messages: List of Message objects containing various types of content
+
+        Yields:
+            Iterator[ModelResponse]: Yields model responses during function execution
+        """
         logger.debug("---------- VertexAI Response Start ----------")
-        # -*- Log messages for debugging
-        for m in messages:
-            m.log()
+        self._log_messages(messages)
+        message_data = MessageData()
+        metrics = Metrics()
 
-        response_role: Optional[str] = None
-        response_function_calls: Optional[List[Dict[str, Any]]] = None
-        assistant_message_content = ""
-        response_timer = Timer()
-        response_timer.start()
+        metrics.response_timer.start()
         for response in self.invoke_stream(messages=messages):
-            # logger.debug(f"VertexAI response type: {type(response)}")
-            # logger.debug(f"VertexAI response: {response}")
             # -*- Parse response
-            response_candidates: List[GenerationResponseCandidate] = response.candidates
-            response_content: GenerationResponseContent = response_candidates[0].content
-            if response_role is None:
-                response_role = response_content.role
-            response_parts: List[GenerationResponsePart] = response_content.parts
-            _part_dict = response_parts[0].to_dict()
-
-            # -*- Return text if present, otherwise get function call
-            if "text" in _part_dict:
-                response_text = _part_dict.get("text")
-                yield response_text
-                assistant_message_content += response_text
-
-            # -*- Parse function calls
-            if "function_call" in _part_dict:
-                if response_function_calls is None:
-                    response_function_calls = []
-                response_function_calls.append(
-                    {
-                        "type": "function",
-                        "function": {
-                            "name": _part_dict.get("function_call").get("name"),
-                            "arguments": json.dumps(_part_dict.get("function_call").get("args")),
-                        },
-                    }
-                )
-
-        response_timer.stop()
-        logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")
+            message_data.response_block = response.candidates[0].content
+            if message_data.response_block is not None:
+                metrics.time_to_first_token = metrics.response_timer.elapsed
+            message_data.response_role = message_data.response_block.role
+            if message_data.response_block.parts:
+                message_data.response_parts = message_data.response_block.parts
+
+            if message_data.response_parts is not None:
+                for part in message_data.response_parts:
+                    part_dict = type(part).to_dict(part)
+
+                    # -*- Yield text if present
+                    if "text" in part_dict:
+                        text = part_dict.get("text")
+                        yield ModelResponse(content=text)
+                        message_data.response_content += text
+
+                    # -*- Skip function calls if there are no parts
+                    if not message_data.response_block.parts and message_data.response_parts:
+                        continue
+                    # -*- Parse function calls
+                    if "function_call" in part_dict:
+                        message_data.response_tool_call_block = response.candidates[0].content
+                        message_data.response_tool_calls.append(
+                            {
+                                "type": "function",
+                                "function": {
+                                    "name": part_dict.get("function_call").get("name"),
+                                    "arguments": json.dumps(part_dict.get("function_call").get("args")),
+                                },
+                            }
+                        )
+            message_data.response_usage = response.usage_metadata
+
+        metrics.response_timer.stop()
 
         # -*- Create assistant message
-        assistant_message = Message(role=response_role or "assistant")
-        # -*- Add content to assistant message
-        if assistant_message_content != "":
-            assistant_message.content = assistant_message_content
-        # -*- Add tool calls to assistant message
-        if response_function_calls is not None:
-            assistant_message.tool_calls = response_function_calls
+        assistant_message = Message(
+            role=message_data.response_role or "assistant",
+            content=message_data.response_content,
+            response_tool_call_block=message_data.response_tool_call_block,
+        )
+
+        # Update assistant message if tool calls are present
+        if len(message_data.response_tool_calls) > 0:
+            assistant_message.tool_calls = message_data.response_tool_calls
+
+        self._update_usage_metrics(
+            assistant_message=assistant_message, metrics=metrics, usage=message_data.response_usage
+        )
 
         # -*- Add assistant message to messages
         messages.append(assistant_message)
-        assistant_message.log()
 
-        # -*- Parse and run function calls
-        if assistant_message.tool_calls is not None:
-            function_calls_to_run: List[FunctionCall] = []
-            for tool_call in assistant_message.tool_calls:
-                _tool_call_id = tool_call.get("id")
-                _function_call = get_function_call_for_tool_call(tool_call, self.functions)
-                if _function_call is None:
-                    messages.append(
-                        Message(role="tool", tool_call_id=_tool_call_id, content="Could not find function to call.")
-                    )
-                    continue
-                if _function_call.error is not None:
-                    messages.append(Message(role="tool", tool_call_id=_tool_call_id, content=_function_call.error))
-                    continue
-                function_calls_to_run.append(_function_call)
-
-            if self.show_tool_calls:
-                if len(function_calls_to_run) == 1:
-                    yield f"\n - Running: {function_calls_to_run[0].get_call_str()}\n\n"
-                elif len(function_calls_to_run) > 1:
-                    yield "\nRunning:"
-                    for _f in function_calls_to_run:
-                        yield f"\n - {_f.get_call_str()}"
-                    yield "\n\n"
-
-            function_call_results = self.run_function_calls(function_calls_to_run)
-            if len(function_call_results) > 0:
-                messages.extend(function_call_results)
-            # -*- Yield new response using results of tool calls
+        # -*- Log response and metrics
+        assistant_message.log()
+        metrics.log()
+
+        if assistant_message.tool_calls is not None and len(assistant_message.tool_calls) > 0 and self.run_tools:
+            yield from self._handle_stream_tool_calls(assistant_message, messages)
             yield from self.response_stream(messages=messages)
+
+        # -*- Remove tool call blocks and tool call results from messages
+        for m in messages:
+            if hasattr(m, "response_tool_call_block"):
+                m.response_tool_call_block = None
+            if hasattr(m, "tool_call_result"):
+                m.tool_call_result = None
         logger.debug("---------- VertexAI Response End ----------")
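To sanity-check the rename of `generative_model` to `client` and the new module path after applying this diff, a short smoke test (a sketch; assumes Vertex AI credentials are configured in the environment):

```python
from phi.model.vertexai import Gemini

model = Gemini(id="gemini-1.5-flash")
client = model.get_client()    # lazily constructs the GenerativeModel
assert model.client is client  # cached on the renamed `client` attribute
print(model.request_kwargs)    # {} until config, safety settings, or tools are set
```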