From 64a9ba1081262e3192af7d15b91eda44b6a4c5aa Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Tue, 17 Sep 2024 12:45:30 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20post-p?= =?UTF-8?q?rocessor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- .../app/chain.py | 15 ++- .../app/eval/utils.py | 29 ++--- .../app/patterns/custom_rag_qa/chain.py | 17 +-- .../app/patterns/custom_rag_qa/templates.py | 31 ++++-- .../patterns/custom_rag_qa/vector_store.py | 4 +- .../patterns/langgraph_dummy_agent/chain.py | 23 ++-- .../app/server.py | 53 +++++----- .../app/utils/input_types.py | 9 +- .../app/utils/output_types.py | 18 +--- .../app/utils/tracing.py | 38 +++---- .../notebooks/getting_started.ipynb | 40 +++++-- .../streamlit/side_bar.py | 82 ++++++++------ .../streamlit/streamlit_app.py | 100 ++++++++---------- .../streamlit/style/app_markdown.py | 2 +- .../streamlit/utils/local_chat_history.py | 30 +++--- .../streamlit/utils/message_editing.py | 34 ++++-- .../streamlit/utils/multimodal_utils.py | 78 ++++++++------ .../streamlit/utils/stream_handler.py | 72 ++++++------- .../streamlit/utils/title_summary.py | 2 +- .../streamlit/utils/utils.py | 4 +- .../patterns/test_langgraph_dummy_agent.py | 21 ++-- .../tests/integration/patterns/test_rag_qa.py | 23 ++-- .../tests/integration/test_chain.py | 23 ++-- .../tests/integration/test_server_e2e.py | 42 +++++--- .../tests/load_test/load_test.py | 12 +-- .../tests/unit/test_server.py | 22 ++-- .../unit/test_utils/test_tracing_exporter.py | 47 ++++---- 27 files changed, 487 insertions(+), 384 deletions(-) diff --git a/gemini/sample-apps/conversational-genai-app-template/app/chain.py b/gemini/sample-apps/conversational-genai-app-template/app/chain.py index c126d460413..e2d2d7d4c1a 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/chain.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/chain.py @@ -24,16 +24,21 @@ } llm = ChatVertexAI( - model_name="gemini-1.5-flash-001", temperature=0, max_output_tokens=1024, - safety_settings=safety_settings + model_name="gemini-1.5-flash-001", + temperature=0, + max_output_tokens=1024, + safety_settings=safety_settings, ) template = ChatPromptTemplate.from_messages( [ - ("system", """You are a conversational bot that produce recipes for users based - on a question."""), - MessagesPlaceholder(variable_name="messages") + ( + "system", + """You are a conversational bot that produce recipes for users based + on a question.""", + ), + MessagesPlaceholder(variable_name="messages"), ] ) diff --git a/gemini/sample-apps/conversational-genai-app-template/app/eval/utils.py b/gemini/sample-apps/conversational-genai-app-template/app/eval/utils.py index 9892381ef63..487088acdf3 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/eval/utils.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/eval/utils.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from concurrent.futures import ThreadPoolExecutor +from functools import partial import glob import logging import os -from concurrent.futures import ThreadPoolExecutor -from functools import partial from typing import Any, Callable, Dict, Iterator, List import nest_asyncio import pandas as pd -import yaml from tqdm import tqdm +import yaml nest_asyncio.apply() + def load_chats(path: str) -> List[Dict[str, Any]]: """ Loads a list of chats from a directory or file. @@ -44,6 +45,7 @@ def load_chats(path: str) -> List[Dict[str, Any]]: chats = chats + chats_in_file return chats + def pairwise(iterable: List[Any]) -> Iterator[tuple[Any, Any]]: """Creates an iterable with tuples paired together e.g s -> (s0, s1), (s2, s3), (s4, s5), ... @@ -81,11 +83,9 @@ def generate_multiturn_history(df: pd.DataFrame) -> pd.DataFrame: message = { "human_message": human_message, "ai_message": ai_message, - "conversation_history": conversation_history + "conversation_history": conversation_history, } - conversation_history = conversation_history + [ - human_message, ai_message - ] + conversation_history = conversation_history + [human_message, ai_message] processed_messages.append(message) return pd.DataFrame(processed_messages) @@ -103,7 +103,7 @@ def generate_message(row: tuple[int, Dict[str, Any]], callable: Any) -> Dict[str Args: row (tuple[int, Dict[str, Any]]): A tuple containing the index and a dictionary with message data, including: - - "conversation_history" (List[str]): Optional. List of previous + - "conversation_history" (List[str]): Optional. List of previous messages in the conversation. - "human_message" (str): The current human message. @@ -118,7 +118,9 @@ def generate_message(row: tuple[int, Dict[str, Any]], callable: Any) -> Dict[str - "response_obj" (Any): The usage metadata of the response from the callable. """ index, message = row - messages = message["conversation_history"] if "conversation_history" in message else [] + messages = ( + message["conversation_history"] if "conversation_history" in message else [] + ) messages.append(message["human_message"]) input_callable = {"messages": messages} response = callable.invoke(input_callable) @@ -130,7 +132,7 @@ def generate_message(row: tuple[int, Dict[str, Any]], callable: Any) -> Dict[str def batch_generate_messages( messages: pd.DataFrame, callable: Callable[[List[Dict[str, Any]]], Dict[str, Any]], - max_workers: int = 4 + max_workers: int = 4, ) -> pd.DataFrame: """Generates AI responses to user messages using a provided callable. @@ -152,7 +154,7 @@ def batch_generate_messages( ] ``` - callable (Callable[[List[Dict[str, Any]]], Dict[str, Any]]): Callable object + callable (Callable[[List[Dict[str, Any]]], Dict[str, Any]]): Callable object (e.g., Langchain Chain) used for response generation. It should accept a list of message dictionaries (as described above) and return a dictionary with the following structure: @@ -202,6 +204,7 @@ def batch_generate_messages( predicted_messages.append(message) return pd.DataFrame(predicted_messages) + def save_df_to_csv(df: pd.DataFrame, dir_path: str, filename: str) -> None: """Saves a pandas DataFrame to directory as a CSV file. 
@@ -233,7 +236,9 @@ def prepare_metrics(metrics: List[str]) -> List[Any]: *module_path, metric_name = metric.removeprefix("custom:").split(".") metrics_evaluation.append( __import__(".".join(module_path), fromlist=[metric_name]).__dict__[ - metric_name]) + metric_name + ] + ) else: metrics_evaluation.append(metric) return metrics_evaluation diff --git a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/chain.py b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/chain.py index 7ef2475287e..f84da947a46 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/chain.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/chain.py @@ -16,14 +16,13 @@ import logging from typing import Any, Dict, Iterator -import google -import vertexai -from langchain_google_community.vertex_rank import VertexAIRank -from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings - from app.patterns.custom_rag_qa.templates import query_rewrite_template, rag_template from app.patterns.custom_rag_qa.vector_store import get_vector_store from app.utils.output_types import OnChatModelStreamEvent, OnToolEndEvent, custom_chain +import google +from langchain_google_community.vertex_rank import VertexAIRank +from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings +import vertexai # Configuration EMBEDDING_MODEL = "text-embedding-004" @@ -52,7 +51,9 @@ @custom_chain -def chain(input: Dict[str, Any], **kwargs: Any) -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]: +def chain( + input: Dict[str, Any], **kwargs: Any +) -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]: """ Implements a RAG QA chain. Decorated with `custom_chain` to offer Langchain compatible astream_events and invoke interface and OpenTelemetry tracing. @@ -69,6 +70,6 @@ def chain(input: Dict[str, Any], **kwargs: Any) -> Iterator[OnToolEndEvent | OnC # Stream LLM response for chunk in response_chain.stream( - input={"messages": input["messages"], "relevant_documents": ranked_docs} + input={"messages": input["messages"], "relevant_documents": ranked_docs} ): - yield OnChatModelStreamEvent(data={"chunk": chunk}) \ No newline at end of file + yield OnChatModelStreamEvent(data={"chunk": chunk}) diff --git a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/templates.py b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/templates.py index bd2a7482c65..d91b9fb76d1 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/templates.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/templates.py @@ -14,14 +14,22 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -query_rewrite_template = ChatPromptTemplate.from_messages([ - ("system", "Rewrite a query to a semantic search engine using the current conversation. " - "Provide only the rewritten query as output."), - MessagesPlaceholder(variable_name="messages") -]) +query_rewrite_template = ChatPromptTemplate.from_messages( + [ + ( + "system", + "Rewrite a query to a semantic search engine using the current conversation. " + "Provide only the rewritten query as output.", + ), + MessagesPlaceholder(variable_name="messages"), + ] +) -rag_template = ChatPromptTemplate.from_messages([ - ("system", """You are an AI assistant for question-answering tasks. 
Follow these guidelines: +rag_template = ChatPromptTemplate.from_messages( + [ + ( + "system", + """You are an AI assistant for question-answering tasks. Follow these guidelines: 1. Use only the provided context to answer the question. 2. Give clear, accurate responses based on the information available. 3. If the context is insufficient, state: "I don't have enough information to answer this question." @@ -39,6 +47,9 @@ {{ doc.page_content | safe }} {% endfor %} -"""), - MessagesPlaceholder(variable_name="messages") -], template_format="jinja2") \ No newline at end of file +""", + ), + MessagesPlaceholder(variable_name="messages"), + ], + template_format="jinja2", +) diff --git a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/vector_store.py b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/vector_store.py index 1738ed8cba7..017d1383a10 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/vector_store.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/vector_store.py @@ -40,8 +40,8 @@ def load_and_split_documents(url: str) -> List[Document]: def get_vector_store( - embedding: Embeddings, persist_path: str = PERSIST_PATH, url: str = URL - ) -> SKLearnVectorStore: + embedding: Embeddings, persist_path: str = PERSIST_PATH, url: str = URL +) -> SKLearnVectorStore: """Get or create a vector store.""" vector_store = SKLearnVectorStore(embedding=embedding, persist_path=persist_path) diff --git a/gemini/sample-apps/conversational-genai-app-template/app/patterns/langgraph_dummy_agent/chain.py b/gemini/sample-apps/conversational-genai-app-template/app/patterns/langgraph_dummy_agent/chain.py index 507b65f4c00..fced6df507a 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/patterns/langgraph_dummy_agent/chain.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/patterns/langgraph_dummy_agent/chain.py @@ -30,27 +30,30 @@ def search(query: str) -> str: return "It's 60 degrees and foggy." return "It's 90 degrees and sunny." + tools = [search] # 2. Set up the language model llm = ChatVertexAI( - model="gemini-1.5-pro-001", - temperature=0, - max_tokens=1024, - streaming=True + model="gemini-1.5-pro-001", temperature=0, max_tokens=1024, streaming=True ).bind_tools(tools) + # 3. Define workflow components def should_continue(state: MessagesState) -> str: """Determines whether to use tools or end the conversation.""" - last_message = state['messages'][-1] - return "tools" if last_message.tool_calls else END # type: ignore[union-attr] + last_message = state["messages"][-1] + return "tools" if last_message.tool_calls else END # type: ignore[union-attr] -async def call_model(state: MessagesState, config: RunnableConfig) -> Dict[str, BaseMessage]: + +async def call_model( + state: MessagesState, config: RunnableConfig +) -> Dict[str, BaseMessage]: """Calls the language model and returns the response.""" - response = llm.invoke(state['messages'], config) + response = llm.invoke(state["messages"], config) return {"messages": response} + # 4. Create the workflow graph workflow = StateGraph(MessagesState) workflow.add_node("agent", call_model) @@ -59,7 +62,7 @@ async def call_model(state: MessagesState, config: RunnableConfig) -> Dict[str, # 5. Define graph edges workflow.add_conditional_edges("agent", should_continue) -workflow.add_edge("tools", 'agent') +workflow.add_edge("tools", "agent") # 6. 
Compile the workflow -chain = workflow.compile() \ No newline at end of file +chain = workflow.compile() diff --git a/gemini/sample-apps/conversational-genai-app-template/app/server.py b/gemini/sample-apps/conversational-genai-app-template/app/server.py index 63ce525fda1..19e1b004580 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/server.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/server.py @@ -15,22 +15,20 @@ import json import logging import os -import uuid from typing import AsyncGenerator +import uuid +# ruff: noqa: I001 +## Import the chain to be used +from app.chain import chain +from app.utils.input_types import Feedback, Input, InputChat, default_serialization +from app.utils.output_types import EndEvent, Event +from app.utils.tracing import CloudTraceLoggingSpanExporter from fastapi import FastAPI from fastapi.responses import RedirectResponse, StreamingResponse from google.cloud import logging as gcp_logging from traceloop.sdk import Instruments, Traceloop -from app.utils.input_types import Feedback, Input, InputChat, default_serialization -from app.utils.output_types import EndEvent, Event -from app.utils.tracing import CloudTraceLoggingSpanExporter - -# ruff: noqa: I001 -## Import the chain to be used -from app.chain import chain - # Or choose one of the following pattern chains to test by uncommenting it: # Custom RAG QA @@ -40,8 +38,13 @@ # from app.patterns.langgraph_dummy_agent.chain import chain # The events that are supported by the UI Frontend -SUPPORTED_EVENTS = ["on_tool_start", "on_tool_end", "on_retriever_start", - "on_retriever_end", "on_chat_model_stream"] +SUPPORTED_EVENTS = [ + "on_tool_start", + "on_tool_end", + "on_retriever_start", + "on_retriever_end", + "on_chat_model_stream", +] # Initialize FastAPI app and logging app = FastAPI() @@ -63,19 +66,20 @@ async def stream_event_response(input_chat: InputChat) -> AsyncGenerator[str, None]: run_id = uuid.uuid4() input_dict = input_chat.model_dump() - Traceloop.set_association_properties({ - "log_type": "tracing", - "run_id": str(run_id), - "user_id": input_dict["user_id"], - "session_id": input_dict["session_id"], - "commit_sha": os.environ.get("COMMIT_SHA", "None") - }) + Traceloop.set_association_properties( + { + "log_type": "tracing", + "run_id": str(run_id), + "user_id": input_dict["user_id"], + "session_id": input_dict["session_id"], + "commit_sha": os.environ.get("COMMIT_SHA", "None"), + } + ) yield json.dumps( - Event( - event="metadata", - data={"run_id": str(run_id)} - ), default=default_serialization) + "\n" + Event(event="metadata", data={"run_id": str(run_id)}), + default=default_serialization, + ) + "\n" async for data in chain.astream_events(input_dict, version="v2"): if data["event"] in SUPPORTED_EVENTS: @@ -97,8 +101,9 @@ async def collect_feedback(feedback_dict: Feedback) -> None: @app.post("/stream_events") async def stream_chat_events(request: Input) -> StreamingResponse: - return StreamingResponse(stream_event_response(input_chat=request.input), - media_type="text/event-stream") + return StreamingResponse( + stream_event_response(input_chat=request.input), media_type="text/event-stream" + ) # Main execution diff --git a/gemini/sample-apps/conversational-genai-app-template/app/utils/input_types.py b/gemini/sample-apps/conversational-genai-app-template/app/utils/input_types.py index 0434fb09e92..b314474a6b3 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/utils/input_types.py +++ 
b/gemini/sample-apps/conversational-genai-app-template/app/utils/input_types.py @@ -20,9 +20,9 @@ class InputChat(BaseModel): """Represents the input for a chat session.""" + messages: List[Union[HumanMessage, AIMessage]] = Field( - ..., - description="The chat messages representing the current conversation." + ..., description="The chat messages representing the current conversation." ) user_id: str = "" session_id: str = "" @@ -30,16 +30,17 @@ class InputChat(BaseModel): class Input(BaseModel): """Wrapper class for InputChat.""" + input: InputChat class Feedback(BaseModel): """Represents feedback for a conversation.""" + score: Union[int, float] text: Optional[str] = None run_id: str - log_type: Literal['feedback'] = 'feedback' - + log_type: Literal["feedback"] = "feedback" def default_serialization(obj: Any) -> Any: diff --git a/gemini/sample-apps/conversational-genai-app-template/app/utils/output_types.py b/gemini/sample-apps/conversational-genai-app-template/app/utils/output_types.py index 21b26914007..1d3bf6066a7 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/utils/output_types.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/utils/output_types.py @@ -12,18 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import uuid from functools import wraps from types import GeneratorType -from typing import ( - Any, - AsyncGenerator, - Callable, - Dict, - List, - Literal, - Union, -) +from typing import Any, AsyncGenerator, Callable, Dict, List, Literal, Union +import uuid from langchain_core.documents import Document from langchain_core.messages import AIMessage, AIMessageChunk @@ -112,8 +104,7 @@ def invoke(self, *args: Any, **kwargs: Any) -> AIMessage: elif isinstance(event, OnToolEndEvent): tool_calls.append(event.data.model_dump()) return AIMessage( - content=response_content, - additional_kwargs={"tool_calls_data": tool_calls} + content=response_content, additional_kwargs={"tool_calls_data": tool_calls} ) def __call__(self, *args: Any, **kwargs: Any) -> Any: @@ -125,8 +116,9 @@ def custom_chain(func: Callable) -> CustomChain: """ Decorator function that wraps a callable in a CustomChain instance. """ + @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: return func(*args, **kwargs) - return CustomChain(wrapper) \ No newline at end of file + return CustomChain(wrapper) diff --git a/gemini/sample-apps/conversational-genai-app-template/app/utils/tracing.py b/gemini/sample-apps/conversational-genai-app-template/app/utils/tracing.py index 63fd4293fb8..c770ae836f6 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/utils/tracing.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/utils/tracing.py @@ -27,17 +27,18 @@ class CloudTraceLoggingSpanExporter(CloudTraceSpanExporter): """ An extended version of CloudTraceSpanExporter that logs span data to Google Cloud Logging and handles large attribute values by storing them in Google Cloud Storage. - + This class helps bypass the 256 character limit of Cloud Trace for attribute values by leveraging Cloud Logging (which has a 256KB limit) and Cloud Storage for larger payloads. """ + def __init__( self, logging_client: Optional[gcp_logging.Client] = None, storage_client: Optional[storage.Client] = None, bucket_name: Optional[str] = None, debug: bool = False, - **kwargs: Any + **kwargs: Any, ) -> None: """ Initialize the exporter with Google Cloud clients and configuration. 
@@ -50,14 +51,15 @@ def __init__( """ super().__init__(**kwargs) self.debug = debug - self.logging_client = logging_client or gcp_logging.Client(project=self.project_id) + self.logging_client = logging_client or gcp_logging.Client( + project=self.project_id + ) self.logger = self.logging_client.logger(__name__) self.storage_client = storage_client or storage.Client(project=self.project_id) self.bucket_name = bucket_name or f"{self.project_id}-logs-data" self._ensure_bucket_exists() self.bucket = self.storage_client.bucket(self.bucket_name) - def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: """ Export the spans to Google Cloud Logging and Cloud Trace. @@ -67,13 +69,13 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: """ for span in spans: span_context = span.get_span_context() - trace_id = format(span_context.trace_id, 'x') - span_id = format(span_context.span_id, 'x') + trace_id = format(span_context.trace_id, "x") + span_id = format(span_context.span_id, "x") span_dict = json.loads(span.to_json()) span_dict["trace"] = f"projects/{self.project_id}/traces/{trace_id}" span_dict["span_id"] = span_id - + span_dict = self._process_large_attributes( span_dict=span_dict, span_id=span_id ) @@ -93,7 +95,6 @@ def _ensure_bucket_exists(self) -> None: logging.info(f"Bucket {self.bucket_name} not detected. Creating it now.") self.storage_client.create_bucket(self.bucket_name) - def store_in_gcs(self, content: str, span_id: str) -> str: """ Initiate storing large content in Google Cloud Storage/ @@ -105,15 +106,14 @@ def store_in_gcs(self, content: str, span_id: str) -> str: blob_name = f"spans/{span_id}.json" blob = self.bucket.blob(blob_name) - blob.upload_from_string(content, 'application/json') + blob.upload_from_string(content, "application/json") return f"gs://{self.bucket_name}/{blob_name}" - def _process_large_attributes(self, span_dict: dict, span_id: str) -> dict: """ - Process large attribute values by storing them in GCS if they exceed the size + Process large attribute values by storing them in GCS if they exceed the size limit of Google Cloud Logging. 
- + :param span_dict: The span data dictionary :param trace_id: The trace ID :param span_id: The span ID @@ -123,14 +123,16 @@ def _process_large_attributes(self, span_dict: dict, span_id: str) -> dict: if len(json.dumps(attributes).encode()) > 255 * 1024: # 250 KB # Separate large payload from other attributes attributes_payload = { - k: v for k, v in attributes.items() + k: v + for k, v in attributes.items() if "traceloop.association.properties" not in k } attributes_retain = { - k: v for k, v in attributes.items() + k: v + for k, v in attributes.items() if "traceloop.association.properties" in k } - + # Store large payload in GCS gcs_uri = self.store_in_gcs(json.dumps(attributes_payload), span_id) attributes_retain["uri_payload"] = gcs_uri @@ -138,11 +140,11 @@ def _process_large_attributes(self, span_dict: dict, span_id: str) -> dict: f"https://storage.mtls.cloud.google.com/" f"{self.bucket_name}/spans/{span_id}.json" ) - + span_dict["attributes"] = attributes_retain logging.info( "Length of payload span above 250 KB, storing attributes in GCS " "to avoid large log entry errors" ) - - return span_dict \ No newline at end of file + + return span_dict diff --git a/gemini/sample-apps/conversational-genai-app-template/notebooks/getting_started.ipynb b/gemini/sample-apps/conversational-genai-app-template/notebooks/getting_started.ipynb index a1639a21469..9289f4dbf56 100644 --- a/gemini/sample-apps/conversational-genai-app-template/notebooks/getting_started.ipynb +++ b/gemini/sample-apps/conversational-genai-app-template/notebooks/getting_started.ipynb @@ -236,6 +236,7 @@ "outputs": [], "source": [ "import sys\n", + "\n", "sys.path.append(\"../\")" ] }, @@ -330,7 +331,13 @@ "metadata": {}, "outputs": [], "source": [ - "SUPPORTED_EVENTS = [\"on_tool_start\", \"on_tool_end\",\"on_retriever_start\", \"on_retriever_end\", \"on_chat_model_stream\"]" + "SUPPORTED_EVENTS = [\n", + " \"on_tool_start\",\n", + " \"on_tool_end\",\n", + " \"on_retriever_start\",\n", + " \"on_retriever_end\",\n", + " \"on_chat_model_stream\",\n", + "]" ] }, { @@ -347,7 +354,7 @@ "metadata": {}, "outputs": [], "source": [ - "llm = ChatVertexAI(model_name=\"gemini-1.5-flash-001\", temperature=0)\n" + "llm = ChatVertexAI(model_name=\"gemini-1.5-flash-001\", temperature=0)" ] }, { @@ -435,22 +442,26 @@ " return \"It's 60 degrees and foggy.\"\n", " return \"It's 90 degrees and sunny.\"\n", "\n", + "\n", "tools = [search]\n", "\n", "# 2. Set up the language model\n", "llm = llm.bind_tools(tools)\n", "\n", + "\n", "# 3. Define workflow components\n", "def should_continue(state: MessagesState) -> Literal[\"tools\", END]:\n", " \"\"\"Determines whether to use tools or end the conversation.\"\"\"\n", - " last_message = state['messages'][-1]\n", + " last_message = state[\"messages\"][-1]\n", " return \"tools\" if last_message.tool_calls else END\n", "\n", + "\n", "async def call_model(state: MessagesState, config: RunnableConfig):\n", " \"\"\"Calls the language model and returns the response.\"\"\"\n", - " response = llm.invoke(state['messages'], config)\n", + " response = llm.invoke(state[\"messages\"], config)\n", " return {\"messages\": response}\n", "\n", + "\n", "# 4. Create the workflow graph\n", "workflow = StateGraph(MessagesState)\n", "workflow.add_node(\"agent\", call_model)\n", @@ -459,7 +470,7 @@ "\n", "# 5. Define graph edges\n", "workflow.add_conditional_edges(\"agent\", should_continue)\n", - "workflow.add_edge(\"tools\", 'agent')\n", + "workflow.add_edge(\"tools\", \"agent\")\n", "\n", "# 6. 
Compile the workflow\n", "chain = workflow.compile()" @@ -529,7 +540,9 @@ "\n", "\n", "@custom_chain\n", - "def chain(input: Dict[str, Any], **kwargs) -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]:\n", + "def chain(\n", + " input: Dict[str, Any], **kwargs\n", + ") -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]:\n", " \"\"\"\n", " Implements a RAG QA chain. Decorated with `custom_chain` to offer Langchain compatible astream_events\n", " and invoke interface and OpenTelemetry tracing.\n", @@ -546,7 +559,7 @@ "\n", " # Stream LLM response\n", " for chunk in response_chain.stream(\n", - " input={\"messages\": input[\"messages\"], \"relevant_documents\": ranked_docs}\n", + " input={\"messages\": input[\"messages\"], \"relevant_documents\": ranked_docs}\n", " ):\n", " yield OnChatModelStreamEvent(data={\"chunk\": chunk})" ] @@ -703,7 +716,7 @@ "source": [ "scored_data[\"user\"] = scored_data[\"human_message\"].apply(lambda x: x[\"content\"])\n", "scored_data[\"reference\"] = scored_data[\"ai_message\"].apply(lambda x: x[\"content\"])\n", - "scored_data\n" + "scored_data" ] }, { @@ -771,7 +784,7 @@ "metadata": {}, "outputs": [], "source": [ - "experiment_name = \"rapid-eval-langchain-eval\" # @param {type:\"string\"}\n" + "experiment_name = \"rapid-eval-langchain-eval\" # @param {type:\"string\"}" ] }, { @@ -791,10 +804,15 @@ "metadata": {}, "outputs": [], "source": [ - "metrics = [\"fluency\", \"safety\", custom_faithfulness_metric]\n", + "metrics = [\"fluency\", \"safety\", custom_faithfulness_metric]\n", "\n", "metrics = [custom_faithfulness_metric]\n", - "eval_task = EvalTask(dataset=scored_data, metrics=metrics, experiment=experiment_name, metric_column_mapping={\"user\":\"prompt\"} )\n", + "eval_task = EvalTask(\n", + " dataset=scored_data,\n", + " metrics=metrics,\n", + " experiment=experiment_name,\n", + " metric_column_mapping={\"user\": \"prompt\"},\n", + ")\n", "eval_result = eval_task.evaluate()" ] }, diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/side_bar.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/side_bar.py index 9c61d3cddfa..90b8b2b0bae 100644 --- a/gemini/sample-apps/conversational-genai-app-template/streamlit/side_bar.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/side_bar.py @@ -28,7 +28,6 @@ class SideBar: - def __init__(self, st) -> None: self.st = st @@ -36,28 +35,36 @@ def init_side_bar(self): with self.st.sidebar: self.url_input_field = self.st.text_input( label="Service URL", - value=os.environ.get("SERVICE_URL", DEFAULT_BASE_URL) + value=os.environ.get("SERVICE_URL", DEFAULT_BASE_URL), ) self.should_authenticate_request = self.st.checkbox( label="Authenticate request", value=False, - help="If checked, any request to the server will contain an" + help="If checked, any request to the server will contain an" "Identity token to allow authentication. 
" "See the Cloud Run documentation to know more about authentication:" - "https://cloud.google.com/run/docs/authenticating/service-to-service" + "https://cloud.google.com/run/docs/authenticating/service-to-service", ) col1, col2, col3 = self.st.columns(3) with col1: if self.st.button("+ New chat"): - if len(self.st.session_state.user_chats[self.st.session_state['session_id']][ - "messages"]) > 0: + if ( + len( + self.st.session_state.user_chats[ + self.st.session_state["session_id"] + ]["messages"] + ) + > 0 + ): self.st.session_state.run_id = None - self.st.session_state['session_id'] = str(uuid.uuid4()) + self.st.session_state["session_id"] = str(uuid.uuid4()) self.st.session_state.session_db.get_session( - session_id=self.st.session_state['session_id'], + session_id=self.st.session_state["session_id"], ) - self.st.session_state.user_chats[self.st.session_state['session_id']] = { + self.st.session_state.user_chats[ + self.st.session_state["session_id"] + ] = { "title": EMPTY_CHAT_NAME, "messages": [], } @@ -66,16 +73,20 @@ def init_side_bar(self): if self.st.button("Delete chat"): self.st.session_state.run_id = None self.st.session_state.session_db.clear() - self.st.session_state.user_chats.pop(self.st.session_state['session_id']) + self.st.session_state.user_chats.pop( + self.st.session_state["session_id"] + ) if len(self.st.session_state.user_chats) > 0: chat_id = list(self.st.session_state.user_chats.keys())[0] - self.st.session_state['session_id'] = chat_id + self.st.session_state["session_id"] = chat_id self.st.session_state.session_db.get_session( - session_id=self.st.session_state['session_id'], + session_id=self.st.session_state["session_id"], ) else: - self.st.session_state['session_id'] = str(uuid.uuid4()) - self.st.session_state.user_chats[self.st.session_state['session_id']] = { + self.st.session_state["session_id"] = str(uuid.uuid4()) + self.st.session_state.user_chats[ + self.st.session_state["session_id"] + ] = { "title": EMPTY_CHAT_NAME, "messages": [], } @@ -84,47 +95,54 @@ def init_side_bar(self): save_chat(self.st) self.st.subheader("Recent") # Style the heading - + all_chats = list(reversed(self.st.session_state.user_chats.items())) for chat_id, chat in all_chats[:NUM_CHAT_IN_RECENT]: if self.st.button(chat["title"], key=chat_id): self.st.session_state.run_id = None - self.st.session_state['session_id'] = chat_id + self.st.session_state["session_id"] = chat_id self.st.session_state.session_db.get_session( - session_id=self.st.session_state['session_id'], + session_id=self.st.session_state["session_id"], ) with self.st.expander("Other chats"): for chat_id, chat in all_chats[NUM_CHAT_IN_RECENT:]: if self.st.button(chat["title"], key=chat_id): self.st.session_state.run_id = None - self.st.session_state['session_id'] = chat_id + self.st.session_state["session_id"] = chat_id self.st.session_state.session_db.get_session( - session_id=self.st.session_state['session_id'], + session_id=self.st.session_state["session_id"], ) self.st.divider() self.st.header("Upload files from local") bucket_name = self.st.text_input( label="GCS Bucket for upload", - value=os.environ.get("BUCKET_NAME","gs://your-bucket-name") + value=os.environ.get("BUCKET_NAME", "gs://your-bucket-name"), ) - if 'checkbox_state' not in self.st.session_state: + if "checkbox_state" not in self.st.session_state: self.st.session_state.checkbox_state = True - + self.st.session_state.checkbox_state = self.st.checkbox( - "Upload to GCS first (suggested)", - value=True, - help=HELP_GCS_CHECKBOX + "Upload to GCS first 
(suggested)", value=True, help=HELP_GCS_CHECKBOX ) self.uploaded_files = self.st.file_uploader( - label="Send files from local", accept_multiple_files=True, + label="Send files from local", + accept_multiple_files=True, key=f"uploader_images_{self.st.session_state.uploader_key}", type=[ - "png", "jpg", "jpeg", "txt", "docx", - "pdf", "rtf", "csv", "tsv", "xlsx" - ] + "png", + "jpg", + "jpeg", + "txt", + "docx", + "pdf", + "rtf", + "csv", + "tsv", + "xlsx", + ], ) if self.uploaded_files and self.st.session_state.checkbox_state: upload_files_to_gcs(self.st, bucket_name, self.uploaded_files) @@ -136,7 +154,7 @@ def init_side_bar(self): "GCS uris (comma-separated)", value=self.st.session_state["gcs_uris_to_be_sent"], key=f"upload_text_area_{self.st.session_state.uploader_key}", - help=HELP_MESSAGE_MULTIMODALITY + help=HELP_MESSAGE_MULTIMODALITY, ) - - self.st.caption(f"Note: {HELP_MESSAGE_MULTIMODALITY}") \ No newline at end of file + + self.st.caption(f"Note: {HELP_MESSAGE_MULTIMODALITY}") diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/streamlit_app.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/streamlit_app.py index 11cad777de2..12947fac963 100644 --- a/gemini/sample-apps/conversational-genai-app-template/streamlit/streamlit_app.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/streamlit_app.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import json import uuid -from functools import partial from langchain_core.messages import HumanMessage from side_bar import SideBar +import streamlit as st from streamlit_feedback import streamlit_feedback from style.app_markdown import markdown_str from utils.local_chat_history import LocalChatMessageHistory @@ -25,8 +26,6 @@ from utils.multimodal_utils import format_content, get_parts_from_files from utils.stream_handler import Client, StreamHandler, get_chain_response -import streamlit as st - USER = "my_user" EMPTY_CHAT_NAME = "Empty chat" @@ -34,14 +33,14 @@ page_title="Playground", layout="wide", initial_sidebar_state="auto", - menu_items=None + menu_items=None, ) st.title("Playground") st.markdown(markdown_str, unsafe_allow_html=True) # First time Init of session variables if "user_chats" not in st.session_state: - st.session_state['session_id'] = str(uuid.uuid4()) + st.session_state["session_id"] = str(uuid.uuid4()) st.session_state.uploader_key = 0 st.session_state.run_id = None st.session_state.user_id = USER @@ -49,37 +48,41 @@ st.session_state["gcs_uris_to_be_sent"] = "" st.session_state.modified_prompt = None st.session_state.session_db = LocalChatMessageHistory( - session_id=st.session_state['session_id'], - user_id=st.session_state['user_id'], + session_id=st.session_state["session_id"], + user_id=st.session_state["user_id"], ) st.session_state.user_chats = st.session_state.session_db.get_all_conversations() - st.session_state.user_chats[st.session_state['session_id']] = { + st.session_state.user_chats[st.session_state["session_id"]] = { "title": EMPTY_CHAT_NAME, "messages": [], - } + } side_bar = SideBar(st=st) side_bar.init_side_bar() -client = Client(url=side_bar.url_input_field, authenticate_request=side_bar.should_authenticate_request) +client = Client( + url=side_bar.url_input_field, + authenticate_request=side_bar.should_authenticate_request, +) # Write all messages of current conversation -messages = 
st.session_state.user_chats[st.session_state['session_id']]["messages"] +messages = st.session_state.user_chats[st.session_state["session_id"]]["messages"] for i, message in enumerate(messages): with st.chat_message(message["type"]): if message["type"] == "ai": - if message.get("tool_calls") and len(message.get("tool_calls")) > 0: - tool_expander = st.expander( - label="Tool Calls:", - expanded=False) + tool_expander = st.expander(label="Tool Calls:", expanded=False) with tool_expander: for index, tool_call in enumerate(message["tool_calls"]): # ruff: noqa: E501 - tool_call_output = message["additional_kwargs"]["tool_calls_outputs"][index] - msg = f"\n\nEnding tool: `{tool_call['name']}` with\n **args:**\n" \ - f"```\n{json.dumps(tool_call['args'], indent=2)}\n```\n" \ - f"\n\n**output:**\n " \ - f"```\n{json.dumps(tool_call_output['output'], indent=2)}\n```" + tool_call_output = message["additional_kwargs"][ + "tool_calls_outputs" + ][index] + msg = ( + f"\n\nEnding tool: `{tool_call['name']}` with\n **args:**\n" + f"```\n{json.dumps(tool_call['args'], indent=2)}\n```\n" + f"\n\n**output:**\n " + f"```\n{json.dumps(tool_call_output['output'], indent=2)}\n```" + ) st.markdown(msg, unsafe_allow_html=True) st.markdown(format_content(message["content"]), unsafe_allow_html=True) @@ -88,29 +91,25 @@ refresh_button = f"{i}_refresh" delete_button = f"{i}_delete" content = message["content"] - - if isinstance(message["content"],list): + + if isinstance(message["content"], list): content = message["content"][-1]["text"] with col1: - st.button( - label="✎", - key=edit_button, - type="primary" - ) + st.button(label="✎", key=edit_button, type="primary") if message["type"] == "human": with col2: st.button( label="⟳", key=refresh_button, type="primary", - on_click=partial(MessageEditing.refresh_message, st, i, content) + on_click=partial(MessageEditing.refresh_message, st, i, content), ) with col3: st.button( label="X", key=delete_button, type="primary", - on_click=partial(MessageEditing.delete_message, st, i) + on_click=partial(MessageEditing.delete_message, st, i), ) if st.session_state[edit_button]: @@ -118,55 +117,52 @@ "Edit your message:", value=content, key=f"edit_box_{i}", - on_change=partial(MessageEditing.edit_message, st, i, message["type"])) + on_change=partial(MessageEditing.edit_message, st, i, message["type"]), + ) # Handle new (or modified) user prompt and response prompt = st.chat_input() if prompt is None: prompt = st.session_state.modified_prompt - + if prompt: st.session_state.modified_prompt = None parts = get_parts_from_files( upload_gcs_checkbox=st.session_state.checkbox_state, - uploaded_files=side_bar.uploaded_files, - gcs_uris=side_bar.gcs_uris + uploaded_files=side_bar.uploaded_files, + gcs_uris=side_bar.gcs_uris, ) st.session_state["gcs_uris_to_be_sent"] = "" - parts.append( - { - "type": "text", - "text": prompt - } + parts.append({"type": "text", "text": prompt}) + st.session_state.user_chats[st.session_state["session_id"]]["messages"].append( + HumanMessage(content=parts).model_dump() ) - st.session_state.user_chats[st.session_state['session_id']]["messages"].append( - HumanMessage(content=parts).model_dump()) human_message = st.chat_message("human") with human_message: existing_user_input = format_content(parts) user_input = st.markdown(existing_user_input, unsafe_allow_html=True) - + ai_message = st.chat_message("ai") with ai_message: status = st.status("Generating answer🤖") stream_handler = StreamHandler(st=st) - get_chain_response( - st=st, - client=client, - 
stream_handler=stream_handler - ) + get_chain_response(st=st, client=client, stream_handler=stream_handler) status.update(label="Finished!", state="complete", expanded=False) - if st.session_state.user_chats[st.session_state['session_id']][ - "title"] == EMPTY_CHAT_NAME: + if ( + st.session_state.user_chats[st.session_state["session_id"]]["title"] + == EMPTY_CHAT_NAME + ): st.session_state.session_db.set_title( - st.session_state.user_chats[st.session_state['session_id']] + st.session_state.user_chats[st.session_state["session_id"]] ) - st.session_state.session_db.upsert_session(st.session_state.user_chats[st.session_state['session_id']]) + st.session_state.session_db.upsert_session( + st.session_state.user_chats[st.session_state["session_id"]] + ) if len(parts) > 1: st.session_state.uploader_key += 1 st.rerun() @@ -175,12 +171,10 @@ feedback = streamlit_feedback( feedback_type="faces", optional_text_label="[Optional] Please provide an explanation", - key=f"feedback-{st.session_state.run_id}" + key=f"feedback-{st.session_state.run_id}", ) if feedback is not None: client.log_feedback( feedback_dict=feedback, run_id=st.session_state.run_id, ) - - diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/style/app_markdown.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/style/app_markdown.py index b91a810e9c3..b1332580450 100644 --- a/gemini/sample-apps/conversational-genai-app-template/streamlit/style/app_markdown.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/style/app_markdown.py @@ -34,4 +34,4 @@ color: !important; } -""" \ No newline at end of file +""" diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/local_chat_history.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/local_chat_history.py index 6e3967ccd39..ebbe0110ed7 100644 --- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/local_chat_history.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/local_chat_history.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os from datetime import datetime +import os -import yaml from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import HumanMessage from utils.title_summary import chain_title +import yaml class LocalChatMessageHistory(BaseChatMessageHistory): def __init__( - self, - user_id: str, - session_id: str = "default", - base_dir: str = ".streamlit_chats" + self, + user_id: str, + session_id: str = "default", + base_dir: str = ".streamlit_chats", ) -> None: self.user_id = user_id self.session_id = session_id @@ -43,12 +43,13 @@ def get_session(self, session_id): def get_all_conversations(self): conversations = {} for filename in os.listdir(self.user_dir): - if filename.endswith('.yaml'): + if filename.endswith(".yaml"): file_path = os.path.join(self.user_dir, filename) - with open(file_path, 'r') as f: + with open(file_path, "r") as f: conversation = yaml.safe_load(f) if not isinstance(conversation, list) or len(conversation) > 1: - raise ValueError(f"""Invalid format in {file_path}. + raise ValueError( + f"""Invalid format in {file_path}. YAML file can only contain one conversation with the following structure. 
- messages: @@ -61,17 +62,18 @@ def get_all_conversations(self): conversation["title"] = filename conversations[filename[:-5]] = conversation return dict( - sorted(conversations.items(), key=lambda x: x[1].get('update_time', ''))) + sorted(conversations.items(), key=lambda x: x[1].get("update_time", "")) + ) def upsert_session(self, session) -> None: - session['update_time'] = datetime.now().isoformat() - with open(self.session_file, 'w') as f: + session["update_time"] = datetime.now().isoformat() + with open(self.session_file, "w") as f: yaml.dump( [session], f, allow_unicode=True, default_flow_style=False, - encoding='utf-8' + encoding="utf-8", ) def set_title(self, session) -> None: @@ -100,4 +102,4 @@ def set_title(self, session) -> None: def clear(self) -> None: if os.path.exists(self.session_file): - os.remove(self.session_file) \ No newline at end of file + os.remove(self.session_file) diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/message_editing.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/message_editing.py index e696714657d..7b829d67d26 100644 --- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/message_editing.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/message_editing.py @@ -12,27 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. -class MessageEditing: +class MessageEditing: @staticmethod def edit_message(st, button_idx, message_type): button_id = f"edit_box_{button_idx}" if message_type == "human": - messages = st.session_state.user_chats[st.session_state['session_id']]["messages"] - st.session_state.user_chats[st.session_state['session_id']][ - "messages"] = messages[:button_idx] + messages = st.session_state.user_chats[st.session_state["session_id"]][ + "messages" + ] + st.session_state.user_chats[st.session_state["session_id"]][ + "messages" + ] = messages[:button_idx] st.session_state.modified_prompt = st.session_state[button_id] else: - st.session_state.user_chats[st.session_state['session_id']]["messages"][ - button_idx]["content"] = st.session_state[button_id] - + st.session_state.user_chats[st.session_state["session_id"]]["messages"][ + button_idx + ]["content"] = st.session_state[button_id] + @staticmethod def refresh_message(st, button_idx, content): - messages = st.session_state.user_chats[st.session_state['session_id']]["messages"] - st.session_state.user_chats[st.session_state['session_id']]["messages"] = messages[:button_idx] + messages = st.session_state.user_chats[st.session_state["session_id"]][ + "messages" + ] + st.session_state.user_chats[st.session_state["session_id"]][ + "messages" + ] = messages[:button_idx] st.session_state.modified_prompt = content @staticmethod def delete_message(st, button_idx): - messages = st.session_state.user_chats[st.session_state['session_id']]["messages"] - st.session_state.user_chats[st.session_state['session_id']]["messages"] = messages[:button_idx] \ No newline at end of file + messages = st.session_state.user_chats[st.session_state["session_id"]][ + "messages" + ] + st.session_state.user_chats[st.session_state["session_id"]][ + "messages" + ] = messages[:button_idx] diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/multimodal_utils.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/multimodal_utils.py index 94d35231058..4c1c8a49738 100644 --- 
a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/multimodal_utils.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/multimodal_utils.py @@ -17,12 +17,16 @@ from google.cloud import storage -HELP_MESSAGE_MULTIMODALITY = "To ensure Gemini models can access the URIs you " \ - "provide, store all URIs in buckets within the same " \ - "GCP Project that Gemini uses." +HELP_MESSAGE_MULTIMODALITY = ( + "To ensure Gemini models can access the URIs you " + "provide, store all URIs in buckets within the same " + "GCP Project that Gemini uses." +) -HELP_GCS_CHECKBOX = "Enabling GCS upload will increase app performance by avoiding to" \ - " pass large byte strings to the model" +HELP_GCS_CHECKBOX = ( + "Enabling GCS upload will increase app performance by avoiding to" + " pass large byte strings to the model" +) def format_content(content): @@ -41,9 +45,12 @@ def format_content(content): if part["type"] == "image_url": image_url = part["image_url"]["url"] image_markdown = f'' - markdown = markdown + f""" + markdown = ( + markdown + + f""" - {image_markdown} """ + ) if part["type"] == "media": # Local other media if "data" in part: @@ -54,19 +61,27 @@ def format_content(content): if "image" in part["mime_type"]: image_url = gs_uri_to_https_url(part["file_uri"]) image_markdown = f'' - markdown = markdown + f""" + markdown = ( + markdown + + f""" - {image_markdown} """ + ) # GCS other media else: - image_url = gs_uri_to_https_url(part['file_uri']) - markdown = markdown + f"- Remote media: " \ - f"[{part['file_uri']}]({image_url})\n" - markdown = markdown + f""" + image_url = gs_uri_to_https_url(part["file_uri"]) + markdown = ( + markdown + f"- Remote media: " + f"[{part['file_uri']}]({image_url})\n" + ) + markdown = ( + markdown + + f""" {text}""" + ) return markdown - + def get_gcs_blob_mime_type(gcs_uri): """Fetches the MIME type (content type) of a Google Cloud Storage blob. @@ -103,17 +118,17 @@ def get_parts_from_files(upload_gcs_checkbox, uploaded_files, gcs_uris): content = { "type": "image_url", "image_url": { - "url": f"data:{uploaded_file.type};base64," \ - f"{base64.b64encode(im_bytes).decode('utf-8')}" + "url": f"data:{uploaded_file.type};base64," + f"{base64.b64encode(im_bytes).decode('utf-8')}" }, "file_name": uploaded_file.name, } else: content = { "type": "media", - "data": base64.b64encode(im_bytes).decode('utf-8'), + "data": base64.b64encode(im_bytes).decode("utf-8"), "file_name": uploaded_file.name, - "mime_type": uploaded_file.type + "mime_type": uploaded_file.type, } parts.append(content) @@ -122,11 +137,12 @@ def get_parts_from_files(upload_gcs_checkbox, uploaded_files, gcs_uris): content = { "type": "media", "file_uri": uri, - "mime_type": get_gcs_blob_mime_type(uri) + "mime_type": get_gcs_blob_mime_type(uri), } parts.append(content) return parts + def upload_bytes_to_gcs(bucket_name, blob_name, file_bytes, content_type=None): """Uploads a bytes object to Google Cloud Storage and returns the GCS URI. 
@@ -177,17 +193,17 @@ def gs_uri_to_https_url(gs_uri): def upload_files_to_gcs(st, bucket_name, files_to_upload): - bucket_name = bucket_name.replace("gs://", "") - uploaded_uris = [] - for file in files_to_upload: - if file: - file_bytes = file.read() - gcs_uri = upload_bytes_to_gcs( - bucket_name=bucket_name, - blob_name=file.name, - file_bytes=file_bytes, - content_type=file.type - ) - uploaded_uris.append(gcs_uri) - st.session_state.uploader_key += 1 - st.session_state["gcs_uris_to_be_sent"] = ",".join(uploaded_uris) + bucket_name = bucket_name.replace("gs://", "") + uploaded_uris = [] + for file in files_to_upload: + if file: + file_bytes = file.read() + gcs_uri = upload_bytes_to_gcs( + bucket_name=bucket_name, + blob_name=file.name, + file_bytes=file_bytes, + content_type=file.type, + ) + uploaded_uris.append(gcs_uri) + st.session_state.uploader_key += 1 + st.session_state["gcs_uris_to_be_sent"] = ",".join(uploaded_uris) diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/stream_handler.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/stream_handler.py index 9c560687152..bfe0e07c12d 100644 --- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/stream_handler.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/stream_handler.py @@ -17,12 +17,11 @@ from urllib.parse import urljoin import google.auth +from google.auth.exceptions import DefaultCredentialsError import google.auth.transport.requests import google.oauth2.id_token -import requests -from google.auth.exceptions import DefaultCredentialsError from langchain_core.messages import AIMessage - +import requests import streamlit as st @@ -78,25 +77,24 @@ def log_feedback(self, feedback_dict, run_id): "Content-Type": "application/json", } if self.authenticate_request: - headers["Authorization"] = f"Bearer {self.id_token}" + headers["Authorization"] = f"Bearer {self.id_token}" requests.post(url, data=json.dumps(feedback_dict), headers=headers) - def stream_events(self, data: Dict[str, Any]) -> Generator[ - Dict[str, Any], None, None]: + def stream_events( + self, data: Dict[str, Any] + ) -> Generator[Dict[str, Any], None, None]: """Stream events from the server, yielding parsed event data.""" - headers = { - "Content-Type": "application/json", - "Accept": "text/event-stream" - } + headers = {"Content-Type": "application/json", "Accept": "text/event-stream"} if self.authenticate_request: - headers["Authorization"] = f"Bearer {self.id_token}" + headers["Authorization"] = f"Bearer {self.id_token}" - with requests.post(self.url, json={"input": data}, headers=headers, - stream=True) as response: + with requests.post( + self.url, json={"input": data}, headers=headers, stream=True + ) as response: for line in response.iter_lines(): if line: try: - event = json.loads(line.decode('utf-8')) + event = json.loads(line.decode("utf-8")) # print(event) yield event except json.JSONDecodeError: @@ -125,9 +123,6 @@ def new_status(self, status_update: str) -> None: self.tool_expander.markdown(status_update) - - - class EventProcessor: """Processes events from the stream and updates the UI accordingly.""" @@ -144,14 +139,14 @@ def __init__(self, st, client, stream_handler): def process_events(self): """Process events from the stream, handling each event type appropriately.""" - messages = \ - self.st.session_state.user_chats[self.st.session_state['session_id']][ - "messages"] + messages = self.st.session_state.user_chats[ + 
self.st.session_state["session_id"] + ]["messages"] stream = self.client.stream_events( data={ "messages": messages, - "user_id": self.st.session_state['user_id'], - "session_id": self.st.session_state['session_id'], + "user_id": self.st.session_state["user_id"], + "session_id": self.st.session_state["session_id"], } ) @@ -173,11 +168,13 @@ def process_events(self): def handle_metadata(self, event: Dict[str, Any]) -> None: """Handle metadata events.""" - self.current_run_id = event['data'].get('run_id') + self.current_run_id = event["data"].get("run_id") def handle_tool_start(self, event: Dict[str, Any]) -> None: """Handle the start of a tool or retriever execution.""" - msg = f"\n\nCalling tool: `{event['name']}` with args: `{event['data']['input']}`" + msg = ( + f"\n\nCalling tool: `{event['name']}` with args: `{event['data']['input']}`" + ) self.stream_handler.new_status(msg) def handle_tool_end(self, event: Dict[str, Any]) -> None: @@ -185,24 +182,26 @@ def handle_tool_end(self, event: Dict[str, Any]) -> None: data = event["data"] # Support tool events if isinstance(data["output"], dict): - tool_id = data["output"].get('tool_call_id', None) - tool_name = data["output"].get('name', 'Unknown Tool') + tool_id = data["output"].get("tool_call_id", None) + tool_name = data["output"].get("name", "Unknown Tool") # Support retriever events else: tool_id = event.get("id", "None") tool_name = event.get("name", event["event"]) - tool_input = data['input'] - tool_output = data['output'] + tool_input = data["input"] + tool_output = data["output"] tool_call = {"id": tool_id, "name": tool_name, "args": tool_input} self.tool_calls.append(tool_call) - tool_call_outputs = {"id":tool_id, "output":tool_output} + tool_call_outputs = {"id": tool_id, "output": tool_output} self.tool_calls_outputs.append(tool_call_outputs) - msg = f"\n\nEnding tool: `{tool_call['name']}` with\n **args:**\n" \ - f"```\n{json.dumps(tool_call['args'], indent=2)}\n```\n" \ - f"\n\n**output:**\n " \ - f"```\n{json.dumps(tool_output, indent=2)}\n```" + msg = ( + f"\n\nEnding tool: `{tool_call['name']}` with\n **args:**\n" + f"```\n{json.dumps(tool_call['args'], indent=2)}\n```\n" + f"\n\n**output:**\n " + f"```\n{json.dumps(tool_output, indent=2)}\n```" + ) self.stream_handler.new_status(msg) def handle_chat_model_stream(self, event: Dict[str, Any]) -> None: @@ -211,7 +210,7 @@ def handle_chat_model_stream(self, event: Dict[str, Any]) -> None: content = data["chunk"]["content"] self.additional_kwargs = { **self.additional_kwargs, - **data["chunk"]["additional_kwargs"] + **data["chunk"]["additional_kwargs"], } if content and len(content.strip()) > 0: self.final_content += content @@ -225,12 +224,13 @@ def handle_end(self, event: Dict[str, Any]) -> None: content=self.final_content, tool_calls=self.tool_calls, id=self.current_run_id, - additional_kwargs=additional_kwargs + additional_kwargs=additional_kwargs, ).model_dump() - session = self.st.session_state['session_id'] + session = self.st.session_state["session_id"] self.st.session_state.user_chats[session]["messages"].append(final_message) self.st.session_state.run_id = self.current_run_id + def get_chain_response(st, client, stream_handler): """Process the chain response update the Streamlit UI. 
diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/title_summary.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/title_summary.py index a34efa51abd..443263c1a4f 100644 --- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/title_summary.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/title_summary.py @@ -63,4 +63,4 @@ MessagesPlaceholder(variable_name="messages"), ]) -chain_title = title_template | llm \ No newline at end of file +chain_title = title_template | llm diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/utils.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/utils.py index 28f40eee878..42cc057a8c5 100644 --- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/utils.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/utils.py @@ -20,8 +20,6 @@ SAVED_CHAT_PATH = str(os.getcwd()) + "/.saved_chats" - - def preprocess_text(text): if text[0] == "\n": text = text[1:] @@ -55,6 +53,6 @@ def save_chat(st): file, allow_unicode=True, default_flow_style=False, - encoding='utf-8', + encoding="utf-8", ) st.toast(f"Chat saved to path: ↓ {Path(SAVED_CHAT_PATH) / filename}") diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_langgraph_dummy_agent.py b/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_langgraph_dummy_agent.py index 25615c9b1fd..0f94b40b7a0 100644 --- a/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_langgraph_dummy_agent.py +++ b/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_langgraph_dummy_agent.py @@ -14,10 +14,9 @@ import logging -import pytest -from langchain_core.messages import AIMessageChunk - from app.patterns.langgraph_dummy_agent.chain import chain +from langchain_core.messages import AIMessageChunk +import pytest CHAIN_NAME = "Langgraph agent" @@ -33,19 +32,21 @@ async def test_langgraph_chain_astream_events() -> None: events = [event async for event in chain.astream_events(input_dict, version="v2")] - assert len(events) > 1, f"Expected multiple events for {CHAIN_NAME} chain, " \ - f"got {len(events)}" + assert len(events) > 1, ( + f"Expected multiple events for {CHAIN_NAME} chain, " f"got {len(events)}" + ) on_chain_stream_events = [ event for event in events if event["event"] == "on_chat_model_stream" ] - assert on_chain_stream_events, f"Expected at least one on_chat_model_stream event" \ - f" for {CHAIN_NAME} chain" + assert on_chain_stream_events, ( + f"Expected at least one on_chat_model_stream event" f" for {CHAIN_NAME} chain" + ) for event in on_chain_stream_events: - assert AIMessageChunk.model_validate(event["data"]["chunk"]), ( - f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}" - ) + assert AIMessageChunk.model_validate( + event["data"]["chunk"] + ), f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}" logging.info(f"All assertions passed for {CHAIN_NAME} chain") diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_rag_qa.py b/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_rag_qa.py index 951631b177b..36ace08010b 100644 --- a/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_rag_qa.py +++ 
b/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_rag_qa.py @@ -14,10 +14,9 @@ import logging -import pytest -from langchain_core.messages import AIMessageChunk - from app.patterns.custom_rag_qa.chain import chain +from langchain_core.messages import AIMessageChunk +import pytest CHAIN_NAME = "Rag QA" @@ -33,19 +32,21 @@ async def test_rag_chain_astream_events() -> None: events = [event async for event in chain.astream_events(input_dict, version="v2")] - assert len(events) > 1, f"Expected multiple events for {CHAIN_NAME} chain, " \ - f"got {len(events)}" + assert len(events) > 1, ( + f"Expected multiple events for {CHAIN_NAME} chain, " f"got {len(events)}" + ) on_chain_stream_events = [ event for event in events if event["event"] == "on_chat_model_stream" - ] + ] - assert on_chain_stream_events, f"Expected at least one on_chat_model_stream event" \ - f" for {CHAIN_NAME} chain" + assert on_chain_stream_events, ( + f"Expected at least one on_chat_model_stream event" f" for {CHAIN_NAME} chain" + ) for event in on_chain_stream_events: - assert AIMessageChunk.model_validate(event["data"]["chunk"]), ( - f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}" - ) + assert AIMessageChunk.model_validate( + event["data"]["chunk"] + ), f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}" logging.info(f"All assertions passed for {CHAIN_NAME} chain") diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_chain.py b/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_chain.py index 9b9da19f584..73c27f93105 100644 --- a/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_chain.py +++ b/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_chain.py @@ -14,10 +14,9 @@ import logging -import pytest -from langchain_core.messages import AIMessageChunk - from app.chain import chain +from langchain_core.messages import AIMessageChunk +import pytest CHAIN_NAME = "Default" @@ -33,19 +32,21 @@ async def test_default_chain_astream_events() -> None: events = [event async for event in chain.astream_events(input_dict, version="v2")] - assert len(events) > 1, f"Expected multiple events for {CHAIN_NAME} chain, " \ - f"got {len(events)}" + assert len(events) > 1, ( + f"Expected multiple events for {CHAIN_NAME} chain, " f"got {len(events)}" + ) on_chain_stream_events = [ event for event in events if event["event"] == "on_chat_model_stream" - ] + ] - assert on_chain_stream_events, f"Expected at least one on_chat_model_stream event" \ - f" for {CHAIN_NAME} chain" + assert on_chain_stream_events, ( + f"Expected at least one on_chat_model_stream event" f" for {CHAIN_NAME} chain" + ) for event in on_chain_stream_events: - assert AIMessageChunk.model_validate(event["data"]["chunk"]), ( - f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}" - ) + assert AIMessageChunk.model_validate( + event["data"]["chunk"] + ), f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}" logging.info(f"All assertions passed for {CHAIN_NAME} chain") diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_server_e2e.py b/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_server_e2e.py index c10e3c3e4d6..1c76307086f 100644 --- a/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_server_e2e.py +++ 
b/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_server_e2e.py @@ -18,8 +18,8 @@ import sys import threading import time -import uuid from typing import Any, Iterator +import uuid import pytest import requests @@ -35,11 +35,13 @@ HEADERS = {"Content-Type": "application/json"} + def log_output(pipe: Any, log_func: Any) -> None: """Log the output from the given pipe.""" - for line in iter(pipe.readline, ''): + for line in iter(pipe.readline, ""): log_func(line.strip()) + def start_server() -> subprocess.Popen[str]: """Start the FastAPI server using subprocess and log its output.""" command = [ @@ -58,18 +60,15 @@ def start_server() -> subprocess.Popen[str]: # Start threads to log stdout and stderr in real-time threading.Thread( - target=log_output, - args=(process.stdout, logger.info), - daemon=True + target=log_output, args=(process.stdout, logger.info), daemon=True ).start() threading.Thread( - target=log_output, - args=(process.stderr, logger.error), - daemon=True + target=log_output, args=(process.stderr, logger.error), daemon=True ).start() return process + def wait_for_server(timeout: int = 60, interval: int = 1) -> bool: """Wait for the server to be ready.""" start_time = time.time() @@ -85,6 +84,7 @@ def wait_for_server(timeout: int = 60, interval: int = 1) -> bool: logger.error(f"Server did not become ready within {timeout} seconds") return False + @pytest.fixture(scope="session") def server_fixture(request: Any) -> Iterator[subprocess.Popen[str]]: """Pytest fixture to start and stop the server for testing.""" @@ -103,6 +103,7 @@ def stop_server() -> None: request.addfinalizer(stop_server) yield server_process + def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None: """Test the chat stream functionality.""" logger.info("Starting chat stream test") @@ -112,10 +113,10 @@ def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None: "messages": [ {"role": "user", "content": "Hello, AI!"}, {"role": "ai", "content": "Hello!"}, - {"role": "user", "content": "What cooking recipes do you suggest?"} + {"role": "user", "content": "What cooking recipes do you suggest?"}, ], "user_id": "test-user", - "session_id": "test-session" + "session_id": "test-session", } } @@ -128,28 +129,35 @@ def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None: logger.info(f"Received {len(events)} events") assert len(events) > 2, f"Expected more than 2 events, got {len(events)}." 
- assert events[0]["event"] == "metadata", f"First event should be 'metadata', " \ - f"got {events[0]['event']}" + assert events[0]["event"] == "metadata", ( + f"First event should be 'metadata', " f"got {events[0]['event']}" + ) assert "run_id" in events[0]["data"], "Missing 'run_id' in metadata" event_types = [event["event"] for event in events] assert "on_chat_model_stream" in event_types, "Missing 'on_chat_model_stream' event" - assert events[-1]["event"] == "end", f"Last event should be 'end', " \ - f"got {events[-1]['event']}" + assert events[-1]["event"] == "end", ( + f"Last event should be 'end', " f"got {events[-1]['event']}" + ) logger.info("Test completed successfully") + def test_chat_stream_error_handling(server_fixture: subprocess.Popen[str]) -> None: """Test the chat stream error handling.""" logger.info("Starting chat stream error handling test") data = {"input": [{"role": "invalid_role", "content": "Cause an error"}]} - response = requests.post(STREAM_EVENTS_URL, headers=HEADERS, json=data, stream=True, timeout=10) + response = requests.post( + STREAM_EVENTS_URL, headers=HEADERS, json=data, stream=True, timeout=10 + ) - assert response.status_code == 422, f"Expected status code 422, " \ - f"got {response.status_code}" + assert response.status_code == 422, ( + f"Expected status code 422, " f"got {response.status_code}" + ) logger.info("Error handling test completed successfully") + def test_collect_feedback(server_fixture: subprocess.Popen[str]) -> None: """ Test the feedback collection endpoint (/feedback) to ensure it properly diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/load_test/load_test.py b/gemini/sample-apps/conversational-genai-app-template/tests/load_test/load_test.py index 0e356953c45..73af9fd5f44 100644 --- a/gemini/sample-apps/conversational-genai-app-template/tests/load_test/load_test.py +++ b/gemini/sample-apps/conversational-genai-app-template/tests/load_test/load_test.py @@ -26,17 +26,17 @@ class ChatStreamUser(HttpUser): def chat_stream(self) -> None: headers = {"Content-Type": "application/json"} if os.environ.get("_ID_TOKEN"): - headers['Authorization'] = f'Bearer {os.environ["_ID_TOKEN"]}' + headers["Authorization"] = f'Bearer {os.environ["_ID_TOKEN"]}' data = { "input": { "messages": [ {"role": "user", "content": "Hello, AI!"}, {"role": "ai", "content": "Hello!"}, - {"role": "user", "content": "Who are you?"} + {"role": "user", "content": "Who are you?"}, ], "user_id": "test-user", - "session_id": "test-session" + "session_id": "test-session", } } @@ -48,7 +48,7 @@ def chat_stream(self) -> None: json=data, catch_response=True, name="/stream_events first event", - stream=True + stream=True, ) as response: if response.status_code == 200: events = [] @@ -73,9 +73,9 @@ def chat_stream(self) -> None: response_time=total_time * 1000, # Convert to milliseconds response_length=len(json.dumps(events)), response=response, - context={} + context={}, ) else: response.failure("Unexpected response structure") else: - response.failure(f"Unexpected status code: {response.status_code}") \ No newline at end of file + response.failure(f"Unexpected status code: {response.status_code}") diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_server.py b/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_server.py index f760cf500d6..51565628faf 100644 --- a/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_server.py +++ 
b/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_server.py @@ -17,13 +17,12 @@ from typing import Any from unittest.mock import patch -import pytest +from app.server import app +from app.utils.input_types import InputChat from fastapi.testclient import TestClient from httpx import AsyncClient from langchain_core.messages import HumanMessage - -from app.server import app -from app.utils.input_types import InputChat +import pytest # Set up logging logging.basicConfig(level=logging.INFO) @@ -41,7 +40,7 @@ def sample_input_chat() -> InputChat: return InputChat( user_id="test-user", session_id="test-session", - messages=[HumanMessage(content="What is the meaning of life?")] + messages=[HumanMessage(content="What is the meaning of life?")], ) @@ -62,7 +61,7 @@ class AsyncIterator: def __init__(self, seq: list) -> None: self.iter = iter(seq) - def __aiter__(self) -> 'AsyncIterator': + def __aiter__(self) -> "AsyncIterator": return self async def __anext__(self) -> Any: @@ -77,7 +76,7 @@ def mock_chain() -> Any: """ Fixture to mock the chain object used in the application. """ - with patch('app.server.chain') as mock: + with patch("app.server.chain") as mock: yield mock @@ -94,8 +93,8 @@ async def test_stream_chat_events(mock_chain: Any) -> None: "messages": [ {"role": "user", "content": "Hello, AI!"}, {"role": "ai", "content": "Hello!"}, - {"role": "user", "content": "What cooking recipes do you suggest?"} - ] + {"role": "user", "content": "What cooking recipes do you suggest?"}, + ], } } @@ -107,8 +106,9 @@ async def test_stream_chat_events(mock_chain: Any) -> None: mock_chain.astream_events.return_value = AsyncIterator(mock_events) - with patch('uuid.uuid4', return_value=mock_uuid), \ - patch('app.server.Traceloop.set_association_properties'): + with patch("uuid.uuid4", return_value=mock_uuid), patch( + "app.server.Traceloop.set_association_properties" + ): async with AsyncClient(app=app, base_url="http://test") as ac: response = await ac.post("/stream_events", json=input_data) diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_utils/test_tracing_exporter.py b/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_utils/test_tracing_exporter.py index 73385724109..e584ee0b704 100644 --- a/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_utils/test_tracing_exporter.py +++ b/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_utils/test_tracing_exporter.py @@ -14,41 +14,47 @@ from unittest.mock import Mock, patch -import pytest +from app.utils.tracing import CloudTraceLoggingSpanExporter from google.cloud import logging as gcp_logging from google.cloud import storage from opentelemetry.sdk.trace import ReadableSpan - -from app.utils.tracing import CloudTraceLoggingSpanExporter +import pytest @pytest.fixture def mock_logging_client() -> Mock: return Mock(spec=gcp_logging.Client) + @pytest.fixture def mock_storage_client() -> Mock: return Mock(spec=storage.Client) + @pytest.fixture -def exporter(mock_logging_client: Mock, mock_storage_client: Mock) -> CloudTraceLoggingSpanExporter: +def exporter( + mock_logging_client: Mock, mock_storage_client: Mock +) -> CloudTraceLoggingSpanExporter: return CloudTraceLoggingSpanExporter( project_id="test-project", logging_client=mock_logging_client, storage_client=mock_storage_client, - bucket_name="test-bucket" + bucket_name="test-bucket", ) + def test_init(exporter: CloudTraceLoggingSpanExporter) -> None: assert exporter.project_id == "test-project" assert 
exporter.bucket_name == "test-bucket" assert exporter.debug is False + def test_ensure_bucket_exists(exporter: CloudTraceLoggingSpanExporter) -> None: exporter.storage_client.bucket.return_value.exists.return_value = False exporter._ensure_bucket_exists() exporter.storage_client.create_bucket.assert_called_once_with("test-bucket") + def test_store_in_gcs(exporter: CloudTraceLoggingSpanExporter) -> None: span_id = "test-span-id" content = "test-content" @@ -56,26 +62,26 @@ def test_store_in_gcs(exporter: CloudTraceLoggingSpanExporter) -> None: assert uri == f"gs://test-bucket/spans/{span_id}.json" exporter.bucket.blob.assert_called_once_with(f"spans/{span_id}.json") -@patch('json.dumps') + +@patch("json.dumps") def test_process_large_attributes_small_payload( - mock_json_dumps: Mock, - exporter: CloudTraceLoggingSpanExporter + mock_json_dumps: Mock, exporter: CloudTraceLoggingSpanExporter ) -> None: - mock_json_dumps.return_value = 'a' * 100 # Small payload + mock_json_dumps.return_value = "a" * 100 # Small payload span_dict = {"attributes": {"key": "value"}} result = exporter._process_large_attributes(span_dict, "span-id") assert result == span_dict -@patch('json.dumps') + +@patch("json.dumps") def test_process_large_attributes_large_payload( - mock_json_dumps: Mock, - exporter: CloudTraceLoggingSpanExporter + mock_json_dumps: Mock, exporter: CloudTraceLoggingSpanExporter ) -> None: - mock_json_dumps.return_value = 'a' * (400 * 1024 + 1) # Large payload + mock_json_dumps.return_value = "a" * (400 * 1024 + 1) # Large payload span_dict = { "attributes": { "key1": "value1", - "traceloop.association.properties.key2": "value2" + "traceloop.association.properties.key2": "value2", } } result = exporter._process_large_attributes(span_dict, "span-id") @@ -84,16 +90,19 @@ def test_process_large_attributes_large_payload( assert "key1" not in result["attributes"] assert "traceloop.association.properties.key2" in result["attributes"] -@patch.object(CloudTraceLoggingSpanExporter, '_process_large_attributes') -def test_export(mock_process_large_attributes: Mock, exporter: CloudTraceLoggingSpanExporter) -> None: + +@patch.object(CloudTraceLoggingSpanExporter, "_process_large_attributes") +def test_export( + mock_process_large_attributes: Mock, exporter: CloudTraceLoggingSpanExporter +) -> None: mock_span = Mock(spec=ReadableSpan) mock_span.get_span_context.return_value.trace_id = 123 mock_span.get_span_context.return_value.span_id = 456 mock_span.to_json.return_value = '{"key": "value"}' - + mock_process_large_attributes.return_value = {"processed": "data"} - + exporter.export([mock_span]) - + mock_process_large_attributes.assert_called_once() exporter.logger.log_struct.assert_called_once()