From 553142d671a93668be69b73240f813386f39b08e Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 7 Oct 2024 21:46:01 +0200 Subject: [PATCH] Fix tool calling in custom RAG pattern --- .../e2e-gen-ai-app-starter-kit/app/README.md | 8 +- .../app/patterns/custom_rag_qa/chain.py | 98 ++++++++++++--- .../app/patterns/custom_rag_qa/templates.py | 54 +++----- .../app/utils/input_types.py | 58 +-------- .../app/utils/output_types.py | 30 +---- .../notebooks/getting_started.ipynb | 116 ++++++++++++++---- .../streamlit/side_bar.py | 2 +- .../streamlit/streamlit_app.py | 15 +-- .../utils/{utils.py => chat_utils.py} | 7 +- .../streamlit/utils/stream_handler.py | 3 +- 10 files changed, 213 insertions(+), 178 deletions(-) rename gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/utils/{utils.py => chat_utils.py} (95%) diff --git a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/README.md b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/README.md index f008ec635b1..247f1422ce8 100644 --- a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/README.md +++ b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/README.md @@ -19,15 +19,13 @@ This folder implements a chatbot application using FastAPI, and Google Cloud services ### 1. Default Chain -The default chain is a simple conversational bot that produces recipes based on user questions. It uses the Gemini 1.5 Flash model. +The default chain is a simple conversational bot that produces recipes based on user questions. ### 2. Custom RAG QA -A pythonic RAG (Retrieval-Augmented Generation) chain designed for maximum flexibility in orchestrating different components. It includes: +A RAG (Retrieval-Augmented Generation) chain that uses Python for orchestration on top of base LangChain components. The chain demonstrates how to create a production-grade application with full control over the orchestration process. -- Query rewriting -- Document retrieval and ranking -- LLM-based response generation +This approach offers maximum flexibility in orchestrating the steps and allows for seamless integration with other SDK frameworks such as the [Vertex AI SDK](https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk) and [LangChain](https://python.langchain.com/), while retaining the ability to emit `astream_events` [API-compatible events](https://python.langchain.com/docs/how_to/streaming/#using-stream-events). ### 3. LangGraph Dummy Agent diff --git a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/patterns/custom_rag_qa/chain.py b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/patterns/custom_rag_qa/chain.py index ede9f376631..9c304b6821b 100644 --- a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/patterns/custom_rag_qa/chain.py +++ b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/patterns/custom_rag_qa/chain.py @@ -12,22 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# mypy: disable-error-code="arg-type" +# mypy: disable-error-code="arg-type,attr-defined" import logging -from typing import Any, Dict, Iterator +from typing import Any, Dict, Iterator, List import google import vertexai +from langchain.schema import Document +from langchain.tools import tool +from langchain_core.messages import ToolMessage from langchain_google_community.vertex_rank import VertexAIRank from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings -from app.patterns.custom_rag_qa.templates import query_rewrite_template, rag_template +from app.patterns.custom_rag_qa.templates import ( + inspect_conversation_template, + rag_template, + template_docs, +) from app.patterns.custom_rag_qa.vector_store import get_vector_store -from app.utils.input_types import extract_human_ai_messages from app.utils.output_types import ( OnChatModelStreamEvent, OnToolEndEvent, - create_on_tool_end_event_from_retrieval, custom_chain, ) @@ -39,11 +44,16 @@ # Initialize logging logging.basicConfig(level=logging.INFO) +# Initialize Google Cloud and Vertex AI credentials, project_id = google.auth.default() vertexai.init(project=project_id) + +# Set up embedding model and vector store embedding = VertexAIEmbeddings(model_name=EMBEDDING_MODEL) vector_store = get_vector_store(embedding=embedding) retriever = vector_store.as_retriever(search_kwargs={"k": 20}) + +# Initialize document compressor compressor = VertexAIRank( project_id=project_id, location_id="global", @@ -51,9 +61,42 @@ title_field="id", top_n=TOP_K, ) + + +@tool +def retrieve_docs(query: str) -> List[Document]: + """ + Useful for retrieving relevant documents based on a query. + Use this when you need additional information to answer a question. + + Args: + query (str): The user's question or search query. + + Returns: + List[Document]: A list of the top-ranked Document objects, limited to TOP_K (5) results. + """ + retrieved_docs = retriever.invoke(query) + ranked_docs = compressor.compress_documents(documents=retrieved_docs, query=query) + return ranked_docs + + +@tool +def should_continue() -> None: + """ + Use this tool if you determine that you have enough context to respond to the user's questions. + """ + return None + + +# Initialize language model llm = ChatVertexAI(model=LLM_MODEL, temperature=0, max_tokens=1024) -query_gen = query_rewrite_template | llm +# Set up conversation inspector +inspect_conversation = inspect_conversation_template | llm.bind_tools( + [retrieve_docs, should_continue], tool_choice="any" +) + +# Set up response chain response_chain = rag_template | llm @@ -62,24 +105,41 @@ def chain( input: Dict[str, Any], **kwargs: Any ) -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]: """ - Implements a RAG QA chain. Decorated with `custom_chain` to offer LangChain compatible astream_events - and invoke interface and OpenTelemetry tracing. + Implement a RAG QA chain with tool calls. + + This function is decorated with `custom_chain` to offer LangChain compatible + astream_events, support for synchronous invocation through the `invoke` method, + and OpenTelemetry tracing. 
""" - # Separate conversation messages from tool calls - input["messages"] = extract_human_ai_messages(input["messages"]) + # Inspect conversation and determine next action + inspection_result = inspect_conversation.invoke(input) + tool_call_result = inspection_result.tool_calls[0] - # Generate optimized query - query = query_gen.invoke(input).content + # Execute the appropriate tool based on the inspection result + if tool_call_result["name"] == "retrieve_docs": + # Retrieve relevant documents + docs = retrieve_docs.invoke(tool_call_result["args"]) + # Format the retrieved documents + formatted_docs = template_docs.format(docs=docs) + # Create a ToolMessage with the formatted documents + tool_message = ToolMessage( + tool_call_id=tool_call_result["name"], + name=tool_call_result["name"], + content=formatted_docs, + artifact=docs, + ) + else: + # If no documents need to be retrieved, continue with the conversation + tool_message = should_continue.invoke(tool_call_result) - # Retrieve and rank documents - retrieved_docs = retriever.invoke(query) - ranked_docs = compressor.compress_documents(documents=retrieved_docs, query=query) + # Update input messages with new information + input["messages"] = input["messages"] + [inspection_result, tool_message] # Yield tool results metadata - yield create_on_tool_end_event_from_retrieval(query=query, docs=ranked_docs) + yield OnToolEndEvent( + data={"input": tool_call_result["args"], "output": tool_message} + ) # Stream LLM response - for chunk in response_chain.stream( - input={"messages": input["messages"], "relevant_documents": ranked_docs} - ): + for chunk in response_chain.stream(input=input): yield OnChatModelStreamEvent(data={"chunk": chunk}) diff --git a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/patterns/custom_rag_qa/templates.py b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/patterns/custom_rag_qa/templates.py index d91b9fb76d1..40a4f93b933 100644 --- a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/patterns/custom_rag_qa/templates.py +++ b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/patterns/custom_rag_qa/templates.py @@ -1,25 +1,26 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +from langchain_core.prompts import ( + ChatPromptTemplate, + MessagesPlaceholder, + PromptTemplate, +) -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +template_docs = PromptTemplate.from_template( + """## Context provided: +{% for doc in docs%} + +{{ doc.page_content | safe }} + +{% endfor %} +""", + template_format="jinja2", +) -query_rewrite_template = ChatPromptTemplate.from_messages( +inspect_conversation_template = ChatPromptTemplate.from_messages( [ ( "system", - "Rewrite a query to a semantic search engine using the current conversation. 
" - "Provide only the rewritten query as output.", + """You are an AI assistant tasked with analyzing the conversation " +and determining the best course of action.""", ), MessagesPlaceholder(variable_name="messages"), ] @@ -29,27 +30,12 @@ [ ( "system", - """You are an AI assistant for question-answering tasks. Follow these guidelines: -1. Use only the provided context to answer the question. -2. Give clear, accurate responses based on the information available. -3. If the context is insufficient, state: "I don't have enough information to answer this question." -4. Don't make assumptions or add information beyond the given context. -5. Use bullet points or lists for clarity when appropriate. -6. Briefly explain technical terms if necessary. -7. Use quotation marks for direct quotes from the context. + """You are an AI assistant for question-answering tasks. Answer to the best of your ability using the context provided. If you're unsure, it's better to acknowledge limitations than to speculate. - -## Context provided: -{% for doc in relevant_documents%} - -{{ doc.page_content | safe }} - -{% endfor %} """, ), MessagesPlaceholder(variable_name="messages"), - ], - template_format="jinja2", + ] ) diff --git a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/utils/input_types.py b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/utils/input_types.py index dbc8e046999..49be86317c5 100644 --- a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/utils/input_types.py +++ b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/utils/input_types.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Annotated, Any, Dict, List, Literal, Optional, Union +from typing import Annotated, Any, List, Literal, Optional, Union from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from pydantic import BaseModel, Field @@ -54,59 +54,3 @@ def default_serialization(obj: Any) -> Any: """ if isinstance(obj, BaseModel): return obj.model_dump() - - -def extract_human_ai_messages( - messages: List[Union[Dict[str, Any], BaseModel]] -) -> List[Dict[str, Any]]: - """ - Extract AI and human messages with non-empty content from a list of messages. - The function will remove all messages relative to tool calls (Empty AI Messages - with tool calls and ToolMessages). - - Args: - messages (List[Union[Dict[str, Any], BaseModel]]): A list of message objects. - - Returns: - List[Dict[str, Any]]: A list of extracted AI and human messages with - non-empty content. - """ - extracted_messages = [] - for message in messages: - if isinstance(message, BaseModel): - message = message.model_dump() - - is_valid_type = message.get("type") in ["human", "ai"] - has_content = bool(message.get("content")) - - if is_valid_type and has_content: - extracted_messages.append(message) - - return extracted_messages - - -def extract_tool_calls_and_messages( - messages: List[Union[Dict[str, Any], BaseModel]] -) -> List[Dict[str, Any]]: - """ - Extract AI Messages with tool calls and ToolMessages from a list of messages. - AI Messages with tool calls define tool inputs, while ToolMessages contain outputs. - - Args: - messages (List[Union[Dict[str, Any], BaseModel]]): A list of message objects. - - Returns: - List[Dict[str, Any]]: A list of extracted tool calls and . 
- """ - extracted_messages = [] - for message in messages: - if isinstance(message, BaseModel): - message = message.model_dump() - - is_tool_message = message.get("type") == "tool" - is_ai_with_tool_call = message.get("type") == "ai" and message.get("tool_calls") - - if is_tool_message or is_ai_with_tool_call: - extracted_messages.append(message) - - return extracted_messages diff --git a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/utils/output_types.py b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/utils/output_types.py index ae07a2ebeeb..a40ea471ce5 100644 --- a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/utils/output_types.py +++ b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/utils/output_types.py @@ -15,9 +15,8 @@ import uuid from functools import wraps from types import GeneratorType -from typing import Any, AsyncGenerator, Callable, Dict, List, Literal +from typing import Any, AsyncGenerator, Callable, Dict, Literal -from langchain_core.documents import Document from langchain_core.messages import AIMessage, AIMessageChunk, ToolMessage from pydantic import BaseModel, Field from traceloop.sdk import TracerWrapper @@ -66,32 +65,6 @@ class EndEvent(BaseModel): event: Literal["end"] = "end" -def create_on_tool_end_event_from_retrieval( - query: str, - docs: List[Document], - tool_call_id: str = "retriever", - name: str = "retriever", -) -> OnToolEndEvent: - """ - Create a LangChain Astream events v2 compatible on_tool_end_event from a retrieval process. - - Args: - query (str): The query used for retrieval. - docs (List[Document]): The retrieved documents. - tool_call_id (str, optional): The ID of the tool call. Defaults to "retriever". - name (str, optional): The name of the tool. Defaults to "retriever". - - Returns: - OnToolEndEvent: An event representing the end of the retrieval tool execution. 
- """ - ranked_docs_tool_output = ToolMessage( - tool_call_id=tool_call_id, content=[doc.model_dump() for doc in docs], name=name - ) - return OnToolEndEvent( - data=ToolData(input={"query": query}, output=ranked_docs_tool_output) - ) - - class CustomChain: """A custom chain class that wraps a callable function.""" @@ -111,6 +84,7 @@ async def astream_events(self, *args: Any, **kwargs: Any) -> AsyncGenerator: func = self.func gen: GeneratorType = func(*args, **kwargs) + for event in gen: yield event.model_dump() diff --git a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/notebooks/getting_started.ipynb b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/notebooks/getting_started.ipynb index 5bffd36a17b..bf14de4b4c6 100644 --- a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/notebooks/getting_started.ipynb +++ b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/notebooks/getting_started.ipynb @@ -247,7 +247,7 @@ "\n", "import vertexai\n", "\n", - "PROJECT_ID = \"production-ai-template\" # @param {type:\"string\", isTemplate: true}\n", + "PROJECT_ID = \"[your-project-id]\" # @param {type:\"string\", isTemplate: true}\n", "if PROJECT_ID == \"[your-project-id]\":\n", " PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n", "\n", @@ -290,8 +290,9 @@ "import json\n", "import pandas as pd\n", "import yaml\n", - "from typing import Any, Dict, Iterator, Literal\n", + "from typing import Any, Dict, Iterator, Literal, List\n", "\n", + "from langchain_core.messages import ToolMessage\n", "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n", "from langchain_core.runnables import RunnableConfig\n", "from langchain_core.tools import tool\n", @@ -299,11 +300,16 @@ "from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings\n", "from langgraph.graph import END, MessagesState, StateGraph\n", "from langgraph.prebuilt import ToolNode\n", + "from langchain.schema import Document\n", "from google.cloud import aiplatform\n", "from vertexai.evaluation import CustomMetric, EvalTask\n", "\n", "from app.eval.utils import batch_generate_messages, generate_multiturn_history\n", - "from app.patterns.custom_rag_qa.templates import query_rewrite_template, rag_template\n", + "from app.patterns.custom_rag_qa.templates import (\n", + " inspect_conversation_template,\n", + " rag_template,\n", + " template_docs,\n", + ")\n", "from app.patterns.custom_rag_qa.vector_store import get_vector_store\n", "from app.utils.output_types import OnChatModelStreamEvent, OnToolEndEvent, custom_chain" ] @@ -416,7 +422,7 @@ "id": "abc296e9da88" }, "source": [ - "### Leveraging LangChain LCEL for Efficient Chain Composition\n", + "### Leverage LangChain LCEL\n", "\n", "LangChain Expression Language (LCEL) provides a declarative approach to composing chains seamlessly. 
Key benefits include:\n", "\n", @@ -470,7 +476,7 @@ }, "outputs": [], "source": [ -    "input_message = {\"messages\": [(\"user\", \"Can you provide me a Lasagne recipe?\")]}\n", +    "input_message = {\"messages\": [(\"human\", \"Can you provide me a Lasagne recipe?\")]}\n", "\n", "async for event in chain.astream_events(input=input_message, version=\"v2\"):\n", "    if event[\"event\"] in SUPPORTED_EVENTS:\n", @@ -586,7 +592,7 @@ }, "outputs": [], "source": [ -    "input_message = {\"messages\": [(\"user\", \"What is the weather like in NY?\")]}\n", +    "input_message = {\"messages\": [(\"human\", \"What is the weather like in NY?\")]}\n", "\n", "async for event in chain.astream_events(input=input_message, version=\"v2\"):\n", "    if event[\"event\"] in SUPPORTED_EVENTS:\n", @@ -620,9 +626,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "ea13644948d2" - }, + "metadata": {}, "outputs": [], "source": [ "llm = ChatVertexAI(model_name=\"gemini-1.5-flash-002\", temperature=0)\n", "\n", "embedding = VertexAIEmbeddings(model_name=\"text-embedding-004\")\n", "vector_store = get_vector_store(embedding=embedding)\n", "retriever = vector_store.as_retriever(search_kwargs={\"k\": 20})\n", "compressor = VertexAIRank(\n", "    project_id=PROJECT_ID,\n", "    location_id=\"global\",\n", "    ranking_config=\"default_ranking_config\",\n", "    title_field=\"id\",\n", "    top_n=5,\n", ")\n", "\n", -    "query_gen = query_rewrite_template | llm\n", +    "\n", +    "@tool\n", +    "def retrieve_docs(query: str) -> List[Document]:\n", +    "    \"\"\"\n", +    "    Useful for retrieving relevant documents based on a query.\n", +    "    Use this when you need additional information to answer a question.\n", +    "\n", +    "    Args:\n", +    "        query (str): The user's question or search query.\n", +    "\n", +    "    Returns:\n", +    "        List[Document]: A list of the top-ranked Document objects, limited to TOP_K (5) results.\n", +    "    \"\"\"\n", +    "    retrieved_docs = retriever.invoke(query)\n", +    "    ranked_docs = compressor.compress_documents(documents=retrieved_docs, query=query)\n", +    "    return ranked_docs\n", +    "\n", +    "\n", +    "@tool\n", +    "def should_continue() -> None:\n", +    "    \"\"\"\n", +    "    Use this tool if you determine that you have enough context to respond to the user's questions.\n", +    "    \"\"\"\n", +    "    return None\n", +    "\n", +    "\n", +    "# Set up conversation inspector\n", +    "inspect_conversation = inspect_conversation_template | llm.bind_tools(\n", +    "    [retrieve_docs, should_continue], tool_choice=\"any\"\n", +    ")\n", +    "\n", +    "# Set up response chain\n", "response_chain = rag_template | llm\n", "\n", "\n", "@custom_chain\n", "def chain(\n", -    "    input: Dict[str, Any], **kwargs\n", +    "    input: Dict[str, Any], **kwargs: Any\n", ") -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]:\n", "    \"\"\"\n", -    "    Implements a RAG QA chain. 
Decorated with `custom_chain` to offer LangChain compatible astream_events\n", - " and invoke interface and OpenTelemetry tracing.\n", - " \"\"\"\n", - " # Generate optimized query\n", - " query = query_gen.invoke(input=input).content\n", + " Implement a RAG QA chain with tool calls.\n", "\n", - " # Retrieve and rank documents\n", - " retrieved_docs = retriever.get_relevant_documents(query)\n", - " ranked_docs = compressor.compress_documents(documents=retrieved_docs, query=query)\n", + " This function is decorated with `custom_chain` to offer LangChain compatible\n", + " astream_events, support for synchronous invocation through the `invoke` method,\n", + " and OpenTelemetry tracing.\n", + " \"\"\"\n", + " # Inspect conversation and determine next action\n", + " inspection_result = inspect_conversation.invoke(input)\n", + " tool_call_result = inspection_result.tool_calls[0]\n", + "\n", + " # Execute the appropriate tool based on the inspection result\n", + " if tool_call_result[\"name\"] == \"retrieve_docs\":\n", + " # Retrieve relevant documents\n", + " docs = retrieve_docs.invoke(tool_call_result[\"args\"])\n", + " # Format the retrieved documents\n", + " formatted_docs = template_docs.format(docs=docs)\n", + " # Create a ToolMessage with the formatted documents\n", + " tool_message = ToolMessage(\n", + " tool_call_id=tool_call_result[\"name\"],\n", + " name=tool_call_result[\"name\"],\n", + " content=formatted_docs,\n", + " artifact=docs,\n", + " )\n", + " else:\n", + " # If no documents need to be retrieved, continue with the conversation\n", + " tool_message = should_continue.invoke(tool_call_result)\n", + "\n", + " # Update input messages with new information\n", + " input[\"messages\"] = input[\"messages\"] + [inspection_result, tool_message]\n", "\n", " # Yield tool results metadata\n", - " yield OnToolEndEvent(data={\"input\": {\"query\": query}, \"output\": ranked_docs})\n", + " yield OnToolEndEvent(\n", + " data={\"input\": tool_call_result[\"args\"], \"output\": tool_message}\n", + " )\n", + "\n", " # Stream LLM response\n", - " for chunk in response_chain.stream(\n", - " input={\"messages\": input[\"messages\"], \"relevant_documents\": ranked_docs}\n", - " ):\n", + " for chunk in response_chain.stream(input=input):\n", " yield OnChatModelStreamEvent(data={\"chunk\": chunk})" ] }, @@ -705,7 +761,7 @@ }, "outputs": [], "source": [ - "input_message = {\"messages\": [(\"user\", \"What is MLOps?\")]}\n", + "input_message = {\"messages\": [(\"human\", \"What is MLOps?\")]}\n", "\n", "async for event in chain.astream_events(input=input_message, version=\"v2\"):\n", " if event[\"event\"] in SUPPORTED_EVENTS:\n", @@ -1097,6 +1153,18 @@ "kernelspec": { "display_name": "Python 3", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" } }, "nbformat": 4, diff --git a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/side_bar.py b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/side_bar.py index fd1c79e304d..015a0c17b16 100644 --- a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/side_bar.py +++ b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/side_bar.py @@ -16,12 +16,12 @@ import uuid from typing import Any +from utils.chat_utils import save_chat from utils.multimodal_utils import ( HELP_GCS_CHECKBOX, HELP_MESSAGE_MULTIMODALITY, 
upload_files_to_gcs, ) -from utils.utils import save_chat EMPTY_CHAT_NAME = "Empty chat" NUM_CHAT_IN_RECENT = 3 diff --git a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/streamlit_app.py b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/streamlit_app.py index 0f5dc26c0a3..a469a603953 100644 --- a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/streamlit_app.py +++ b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/streamlit_app.py @@ -73,7 +73,7 @@ def display_messages() -> None: elif "tool_calls" in message and message["tool_calls"]: tool_call_input = handle_tool_call(message) elif message["type"] == "tool" and tool_call_input is not None: - display_tool_output(tool_call_input, message["content"]) + display_tool_output(tool_call_input, message) tool_call_input = None else: st.error(f"Unexpected message type: {message['type']}") @@ -165,7 +165,8 @@ def handle_user_input(side_bar: SideBar) -> None: display_user_input(parts) generate_ai_response( - side_bar.url_input_field, side_bar.should_authenticate_request + url_input_field=side_bar.url_input_field, + should_authenticate_request=side_bar.should_authenticate_request, ) update_chat_title() if len(parts) > 1: @@ -207,7 +208,7 @@ def update_chat_title() -> None: ) -def display_feedback() -> None: +def display_feedback(side_bar: SideBar) -> None: if st.session_state.run_id is not None: feedback = streamlit_feedback( feedback_type="faces", @@ -216,8 +217,8 @@ def display_feedback() -> None: ) if feedback is not None: client = Client( - url=st.session_state.url_input_field, - authenticate_request=st.session_state.should_authenticate_request, + url=side_bar.url_input_field, + authenticate_request=side_bar.should_authenticate_request, ) client.log_feedback( feedback_dict=feedback, @@ -231,8 +232,8 @@ def main() -> None: side_bar = SideBar(st=st) side_bar.init_side_bar() display_messages() - handle_user_input(side_bar) - display_feedback() + handle_user_input(side_bar=side_bar) + display_feedback(side_bar=side_bar) if __name__ == "__main__": diff --git a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/utils/utils.py b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/utils/chat_utils.py similarity index 95% rename from gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/utils/utils.py rename to gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/utils/chat_utils.py index a198b032d5f..0ec34b0a6c2 100644 --- a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/utils/utils.py +++ b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/utils/chat_utils.py @@ -22,9 +22,12 @@ def preprocess_text(text: str) -> str: - if text[0] == "\n": + if not text: + return text + + if text.startswith("\n"): text = text[1:] - if text[-1] == "\n": + if text.endswith("\n"): text = text[:-1] return text diff --git a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/utils/stream_handler.py b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/utils/stream_handler.py index 3749837492f..c28c161c389 100644 --- a/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/utils/stream_handler.py +++ b/gemini/sample-apps/e2e-gen-ai-app-starter-kit/streamlit/utils/stream_handler.py @@ -22,6 +22,7 @@ import requests from google.auth.exceptions import DefaultCredentialsError from langchain_core.messages import AIMessage, ToolMessage +from utils.multimodal_utils import format_content import streamlit as st @@ -114,7 +115,7 @@ def __init__(self, st: Any, initial_text: str = "") -> None: def new_token(self, token: 
str) -> None: """Add a new token to the main text display.""" self.text += token - self.container.markdown(self.text, unsafe_allow_html=True) + self.container.markdown(format_content(self.text), unsafe_allow_html=True) def new_status(self, status_update: str) -> None: """Add a new status update to the tool calls expander."""
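
For reviewers who want to exercise the change end to end, below is a minimal smoke-test sketch of the reworked custom RAG chain after the patch is applied. It is not part of the diff, and its setup is assumed: run from the starter-kit root with the `app` package importable and Google Cloud application-default credentials configured. It simply replays the notebook's `astream_events` usage against `app/patterns/custom_rag_qa/chain.py`.

```python
# Minimal smoke-test sketch (not part of the patch). Assumes the starter-kit
# root is on PYTHONPATH and application-default credentials are configured.
import asyncio

from app.patterns.custom_rag_qa.chain import chain


async def main() -> None:
    # The reworked chain first lets the model pick a tool (`retrieve_docs` or
    # `should_continue`; `tool_choice="any"` forces one call), then streams
    # the grounded answer.
    input_message = {"messages": [("human", "What is MLOps?")]}
    async for event in chain.astream_events(input=input_message, version="v2"):
        # Each yielded event is the model_dump() of an OnToolEndEvent or
        # OnChatModelStreamEvent from app.utils.output_types.
        print(event["event"], event["data"])


asyncio.run(main())
```

The synchronous path goes through the `invoke` method that the `custom_chain` decorator adds, as noted in the updated docstring.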
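
The `utils.py` → `chat_utils.py` rename also hardens `preprocess_text` against empty input. A self-contained check of the new guard clauses (the function body is copied from the hunk above so the sketch runs standalone; the assertions are illustrative, not part of the patch):

```python
# Patched preprocess_text from streamlit/utils/chat_utils.py, copied verbatim.
def preprocess_text(text: str) -> str:
    if not text:
        return text

    if text.startswith("\n"):
        text = text[1:]
    if text.endswith("\n"):
        text = text[:-1]
    return text


assert preprocess_text("") == ""  # the old text[0] check raised IndexError here
assert preprocess_text("\nrecipe\n") == "recipe"  # leading/trailing newline stripped
assert preprocess_text("recipe") == "recipe"  # other input passes through unchanged
print("preprocess_text guards OK")
```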