diff --git a/gemini/sample-apps/conversational-genai-app-template/app/chain.py b/gemini/sample-apps/conversational-genai-app-template/app/chain.py
index c126d460413..e2d2d7d4c1a 100644
--- a/gemini/sample-apps/conversational-genai-app-template/app/chain.py
+++ b/gemini/sample-apps/conversational-genai-app-template/app/chain.py
@@ -24,16 +24,21 @@
}
llm = ChatVertexAI(
- model_name="gemini-1.5-flash-001", temperature=0, max_output_tokens=1024,
- safety_settings=safety_settings
+ model_name="gemini-1.5-flash-001",
+ temperature=0,
+ max_output_tokens=1024,
+ safety_settings=safety_settings,
)
template = ChatPromptTemplate.from_messages(
[
- ("system", """You are a conversational bot that produce recipes for users based
- on a question."""),
- MessagesPlaceholder(variable_name="messages")
+ (
+ "system",
+ """You are a conversational bot that produce recipes for users based
+ on a question.""",
+ ),
+ MessagesPlaceholder(variable_name="messages"),
]
)
diff --git a/gemini/sample-apps/conversational-genai-app-template/app/eval/utils.py b/gemini/sample-apps/conversational-genai-app-template/app/eval/utils.py
index 9892381ef63..487088acdf3 100644
--- a/gemini/sample-apps/conversational-genai-app-template/app/eval/utils.py
+++ b/gemini/sample-apps/conversational-genai-app-template/app/eval/utils.py
@@ -12,20 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
import glob
import logging
import os
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
from typing import Any, Callable, Dict, Iterator, List
import nest_asyncio
import pandas as pd
-import yaml
from tqdm import tqdm
+import yaml
nest_asyncio.apply()
+
def load_chats(path: str) -> List[Dict[str, Any]]:
"""
Loads a list of chats from a directory or file.
@@ -44,6 +45,7 @@ def load_chats(path: str) -> List[Dict[str, Any]]:
chats = chats + chats_in_file
return chats
+
def pairwise(iterable: List[Any]) -> Iterator[tuple[Any, Any]]:
"""Creates an iterable with tuples paired together
e.g s -> (s0, s1), (s2, s3), (s4, s5), ...
@@ -81,11 +83,9 @@ def generate_multiturn_history(df: pd.DataFrame) -> pd.DataFrame:
message = {
"human_message": human_message,
"ai_message": ai_message,
- "conversation_history": conversation_history
+ "conversation_history": conversation_history,
}
- conversation_history = conversation_history + [
- human_message, ai_message
- ]
+ conversation_history = conversation_history + [human_message, ai_message]
processed_messages.append(message)
return pd.DataFrame(processed_messages)
@@ -103,7 +103,7 @@ def generate_message(row: tuple[int, Dict[str, Any]], callable: Any) -> Dict[str
Args:
row (tuple[int, Dict[str, Any]]): A tuple containing the index and a dictionary
with message data, including:
- - "conversation_history" (List[str]): Optional. List of previous
+ - "conversation_history" (List[str]): Optional. List of previous
messages
in the conversation.
- "human_message" (str): The current human message.
@@ -118,7 +118,9 @@ def generate_message(row: tuple[int, Dict[str, Any]], callable: Any) -> Dict[str
- "response_obj" (Any): The usage metadata of the response from the callable.
"""
index, message = row
- messages = message["conversation_history"] if "conversation_history" in message else []
+ messages = (
+ message["conversation_history"] if "conversation_history" in message else []
+ )
messages.append(message["human_message"])
input_callable = {"messages": messages}
response = callable.invoke(input_callable)
@@ -130,7 +132,7 @@ def generate_message(row: tuple[int, Dict[str, Any]], callable: Any) -> Dict[str
def batch_generate_messages(
messages: pd.DataFrame,
callable: Callable[[List[Dict[str, Any]]], Dict[str, Any]],
- max_workers: int = 4
+ max_workers: int = 4,
) -> pd.DataFrame:
"""Generates AI responses to user messages using a provided callable.
@@ -152,7 +154,7 @@ def batch_generate_messages(
]
```
- callable (Callable[[List[Dict[str, Any]]], Dict[str, Any]]): Callable object
+ callable (Callable[[List[Dict[str, Any]]], Dict[str, Any]]): Callable object
(e.g., Langchain Chain) used
for response generation. It should accept a list of message dictionaries
(as described above) and return a dictionary with the following structure:
@@ -202,6 +204,7 @@ def batch_generate_messages(
predicted_messages.append(message)
return pd.DataFrame(predicted_messages)
+
def save_df_to_csv(df: pd.DataFrame, dir_path: str, filename: str) -> None:
"""Saves a pandas DataFrame to directory as a CSV file.
@@ -233,7 +236,9 @@ def prepare_metrics(metrics: List[str]) -> List[Any]:
*module_path, metric_name = metric.removeprefix("custom:").split(".")
metrics_evaluation.append(
__import__(".".join(module_path), fromlist=[metric_name]).__dict__[
- metric_name])
+ metric_name
+ ]
+ )
else:
metrics_evaluation.append(metric)
return metrics_evaluation
diff --git a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/chain.py b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/chain.py
index 7ef2475287e..f84da947a46 100644
--- a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/chain.py
+++ b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/chain.py
@@ -16,14 +16,13 @@
import logging
from typing import Any, Dict, Iterator
-import google
-import vertexai
-from langchain_google_community.vertex_rank import VertexAIRank
-from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
-
from app.patterns.custom_rag_qa.templates import query_rewrite_template, rag_template
from app.patterns.custom_rag_qa.vector_store import get_vector_store
from app.utils.output_types import OnChatModelStreamEvent, OnToolEndEvent, custom_chain
+import google
+from langchain_google_community.vertex_rank import VertexAIRank
+from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
+import vertexai
# Configuration
EMBEDDING_MODEL = "text-embedding-004"
@@ -52,7 +51,9 @@
@custom_chain
-def chain(input: Dict[str, Any], **kwargs: Any) -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]:
+def chain(
+ input: Dict[str, Any], **kwargs: Any
+) -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]:
"""
Implements a RAG QA chain. Decorated with `custom_chain` to offer Langchain compatible astream_events
and invoke interface and OpenTelemetry tracing.
@@ -69,6 +70,6 @@ def chain(input: Dict[str, Any], **kwargs: Any) -> Iterator[OnToolEndEvent | OnC
# Stream LLM response
for chunk in response_chain.stream(
- input={"messages": input["messages"], "relevant_documents": ranked_docs}
+ input={"messages": input["messages"], "relevant_documents": ranked_docs}
):
- yield OnChatModelStreamEvent(data={"chunk": chunk})
\ No newline at end of file
+ yield OnChatModelStreamEvent(data={"chunk": chunk})
diff --git a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/templates.py b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/templates.py
index bd2a7482c65..d91b9fb76d1 100644
--- a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/templates.py
+++ b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/templates.py
@@ -14,14 +14,22 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-query_rewrite_template = ChatPromptTemplate.from_messages([
- ("system", "Rewrite a query to a semantic search engine using the current conversation. "
- "Provide only the rewritten query as output."),
- MessagesPlaceholder(variable_name="messages")
-])
+query_rewrite_template = ChatPromptTemplate.from_messages(
+ [
+ (
+ "system",
+ "Rewrite a query to a semantic search engine using the current conversation. "
+ "Provide only the rewritten query as output.",
+ ),
+ MessagesPlaceholder(variable_name="messages"),
+ ]
+)
-rag_template = ChatPromptTemplate.from_messages([
- ("system", """You are an AI assistant for question-answering tasks. Follow these guidelines:
+rag_template = ChatPromptTemplate.from_messages(
+ [
+ (
+ "system",
+ """You are an AI assistant for question-answering tasks. Follow these guidelines:
1. Use only the provided context to answer the question.
2. Give clear, accurate responses based on the information available.
3. If the context is insufficient, state: "I don't have enough information to answer this question."
@@ -39,6 +47,9 @@
{{ doc.page_content | safe }}
{% endfor %}
-"""),
- MessagesPlaceholder(variable_name="messages")
-], template_format="jinja2")
\ No newline at end of file
+""",
+ ),
+ MessagesPlaceholder(variable_name="messages"),
+ ],
+ template_format="jinja2",
+)
diff --git a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/vector_store.py b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/vector_store.py
index 1738ed8cba7..017d1383a10 100644
--- a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/vector_store.py
+++ b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/vector_store.py
@@ -40,8 +40,8 @@ def load_and_split_documents(url: str) -> List[Document]:
def get_vector_store(
- embedding: Embeddings, persist_path: str = PERSIST_PATH, url: str = URL
- ) -> SKLearnVectorStore:
+ embedding: Embeddings, persist_path: str = PERSIST_PATH, url: str = URL
+) -> SKLearnVectorStore:
"""Get or create a vector store."""
vector_store = SKLearnVectorStore(embedding=embedding, persist_path=persist_path)
diff --git a/gemini/sample-apps/conversational-genai-app-template/app/patterns/langgraph_dummy_agent/chain.py b/gemini/sample-apps/conversational-genai-app-template/app/patterns/langgraph_dummy_agent/chain.py
index 507b65f4c00..fced6df507a 100644
--- a/gemini/sample-apps/conversational-genai-app-template/app/patterns/langgraph_dummy_agent/chain.py
+++ b/gemini/sample-apps/conversational-genai-app-template/app/patterns/langgraph_dummy_agent/chain.py
@@ -30,27 +30,30 @@ def search(query: str) -> str:
return "It's 60 degrees and foggy."
return "It's 90 degrees and sunny."
+
tools = [search]
# 2. Set up the language model
llm = ChatVertexAI(
- model="gemini-1.5-pro-001",
- temperature=0,
- max_tokens=1024,
- streaming=True
+ model="gemini-1.5-pro-001", temperature=0, max_tokens=1024, streaming=True
).bind_tools(tools)
+
# 3. Define workflow components
def should_continue(state: MessagesState) -> str:
"""Determines whether to use tools or end the conversation."""
- last_message = state['messages'][-1]
- return "tools" if last_message.tool_calls else END # type: ignore[union-attr]
+ last_message = state["messages"][-1]
+ return "tools" if last_message.tool_calls else END # type: ignore[union-attr]
-async def call_model(state: MessagesState, config: RunnableConfig) -> Dict[str, BaseMessage]:
+
+async def call_model(
+ state: MessagesState, config: RunnableConfig
+) -> Dict[str, BaseMessage]:
"""Calls the language model and returns the response."""
- response = llm.invoke(state['messages'], config)
+ response = llm.invoke(state["messages"], config)
return {"messages": response}
+
# 4. Create the workflow graph
workflow = StateGraph(MessagesState)
workflow.add_node("agent", call_model)
@@ -59,7 +62,7 @@ async def call_model(state: MessagesState, config: RunnableConfig) -> Dict[str,
# 5. Define graph edges
workflow.add_conditional_edges("agent", should_continue)
-workflow.add_edge("tools", 'agent')
+workflow.add_edge("tools", "agent")
# 6. Compile the workflow
-chain = workflow.compile()
\ No newline at end of file
+chain = workflow.compile()
diff --git a/gemini/sample-apps/conversational-genai-app-template/app/server.py b/gemini/sample-apps/conversational-genai-app-template/app/server.py
index 63ce525fda1..19e1b004580 100644
--- a/gemini/sample-apps/conversational-genai-app-template/app/server.py
+++ b/gemini/sample-apps/conversational-genai-app-template/app/server.py
@@ -15,22 +15,20 @@
import json
import logging
import os
-import uuid
from typing import AsyncGenerator
+import uuid
+# ruff: noqa: I001
+## Import the chain to be used
+from app.chain import chain
+from app.utils.input_types import Feedback, Input, InputChat, default_serialization
+from app.utils.output_types import EndEvent, Event
+from app.utils.tracing import CloudTraceLoggingSpanExporter
from fastapi import FastAPI
from fastapi.responses import RedirectResponse, StreamingResponse
from google.cloud import logging as gcp_logging
from traceloop.sdk import Instruments, Traceloop
-from app.utils.input_types import Feedback, Input, InputChat, default_serialization
-from app.utils.output_types import EndEvent, Event
-from app.utils.tracing import CloudTraceLoggingSpanExporter
-
-# ruff: noqa: I001
-## Import the chain to be used
-from app.chain import chain
-
# Or choose one of the following pattern chains to test by uncommenting it:
# Custom RAG QA
@@ -40,8 +38,13 @@
# from app.patterns.langgraph_dummy_agent.chain import chain
# The events that are supported by the UI Frontend
-SUPPORTED_EVENTS = ["on_tool_start", "on_tool_end", "on_retriever_start",
- "on_retriever_end", "on_chat_model_stream"]
+SUPPORTED_EVENTS = [
+ "on_tool_start",
+ "on_tool_end",
+ "on_retriever_start",
+ "on_retriever_end",
+ "on_chat_model_stream",
+]
# Initialize FastAPI app and logging
app = FastAPI()
@@ -63,19 +66,20 @@
async def stream_event_response(input_chat: InputChat) -> AsyncGenerator[str, None]:
run_id = uuid.uuid4()
input_dict = input_chat.model_dump()
- Traceloop.set_association_properties({
- "log_type": "tracing",
- "run_id": str(run_id),
- "user_id": input_dict["user_id"],
- "session_id": input_dict["session_id"],
- "commit_sha": os.environ.get("COMMIT_SHA", "None")
- })
+ Traceloop.set_association_properties(
+ {
+ "log_type": "tracing",
+ "run_id": str(run_id),
+ "user_id": input_dict["user_id"],
+ "session_id": input_dict["session_id"],
+ "commit_sha": os.environ.get("COMMIT_SHA", "None"),
+ }
+ )
yield json.dumps(
- Event(
- event="metadata",
- data={"run_id": str(run_id)}
- ), default=default_serialization) + "\n"
+ Event(event="metadata", data={"run_id": str(run_id)}),
+ default=default_serialization,
+ ) + "\n"
async for data in chain.astream_events(input_dict, version="v2"):
if data["event"] in SUPPORTED_EVENTS:
@@ -97,8 +101,9 @@ async def collect_feedback(feedback_dict: Feedback) -> None:
@app.post("/stream_events")
async def stream_chat_events(request: Input) -> StreamingResponse:
- return StreamingResponse(stream_event_response(input_chat=request.input),
- media_type="text/event-stream")
+ return StreamingResponse(
+ stream_event_response(input_chat=request.input), media_type="text/event-stream"
+ )
# Main execution
diff --git a/gemini/sample-apps/conversational-genai-app-template/app/utils/input_types.py b/gemini/sample-apps/conversational-genai-app-template/app/utils/input_types.py
index 0434fb09e92..b314474a6b3 100644
--- a/gemini/sample-apps/conversational-genai-app-template/app/utils/input_types.py
+++ b/gemini/sample-apps/conversational-genai-app-template/app/utils/input_types.py
@@ -20,9 +20,9 @@
class InputChat(BaseModel):
"""Represents the input for a chat session."""
+
messages: List[Union[HumanMessage, AIMessage]] = Field(
- ...,
- description="The chat messages representing the current conversation."
+ ..., description="The chat messages representing the current conversation."
)
user_id: str = ""
session_id: str = ""
@@ -30,16 +30,17 @@ class InputChat(BaseModel):
class Input(BaseModel):
"""Wrapper class for InputChat."""
+
input: InputChat
class Feedback(BaseModel):
"""Represents feedback for a conversation."""
+
score: Union[int, float]
text: Optional[str] = None
run_id: str
- log_type: Literal['feedback'] = 'feedback'
-
+ log_type: Literal["feedback"] = "feedback"
def default_serialization(obj: Any) -> Any:
diff --git a/gemini/sample-apps/conversational-genai-app-template/app/utils/output_types.py b/gemini/sample-apps/conversational-genai-app-template/app/utils/output_types.py
index 21b26914007..1d3bf6066a7 100644
--- a/gemini/sample-apps/conversational-genai-app-template/app/utils/output_types.py
+++ b/gemini/sample-apps/conversational-genai-app-template/app/utils/output_types.py
@@ -12,18 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import uuid
from functools import wraps
from types import GeneratorType
-from typing import (
- Any,
- AsyncGenerator,
- Callable,
- Dict,
- List,
- Literal,
- Union,
-)
+from typing import Any, AsyncGenerator, Callable, Dict, List, Literal, Union
+import uuid
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, AIMessageChunk
@@ -112,8 +104,7 @@ def invoke(self, *args: Any, **kwargs: Any) -> AIMessage:
elif isinstance(event, OnToolEndEvent):
tool_calls.append(event.data.model_dump())
return AIMessage(
- content=response_content,
- additional_kwargs={"tool_calls_data": tool_calls}
+ content=response_content, additional_kwargs={"tool_calls_data": tool_calls}
)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
@@ -125,8 +116,9 @@ def custom_chain(func: Callable) -> CustomChain:
"""
Decorator function that wraps a callable in a CustomChain instance.
"""
+
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)
- return CustomChain(wrapper)
\ No newline at end of file
+ return CustomChain(wrapper)
diff --git a/gemini/sample-apps/conversational-genai-app-template/app/utils/tracing.py b/gemini/sample-apps/conversational-genai-app-template/app/utils/tracing.py
index 63fd4293fb8..c770ae836f6 100644
--- a/gemini/sample-apps/conversational-genai-app-template/app/utils/tracing.py
+++ b/gemini/sample-apps/conversational-genai-app-template/app/utils/tracing.py
@@ -27,17 +27,18 @@ class CloudTraceLoggingSpanExporter(CloudTraceSpanExporter):
"""
An extended version of CloudTraceSpanExporter that logs span data to Google Cloud Logging
and handles large attribute values by storing them in Google Cloud Storage.
-
+
This class helps bypass the 256 character limit of Cloud Trace for attribute values
by leveraging Cloud Logging (which has a 256KB limit) and Cloud Storage for larger payloads.
"""
+
def __init__(
self,
logging_client: Optional[gcp_logging.Client] = None,
storage_client: Optional[storage.Client] = None,
bucket_name: Optional[str] = None,
debug: bool = False,
- **kwargs: Any
+ **kwargs: Any,
) -> None:
"""
Initialize the exporter with Google Cloud clients and configuration.
@@ -50,14 +51,15 @@ def __init__(
"""
super().__init__(**kwargs)
self.debug = debug
- self.logging_client = logging_client or gcp_logging.Client(project=self.project_id)
+ self.logging_client = logging_client or gcp_logging.Client(
+ project=self.project_id
+ )
self.logger = self.logging_client.logger(__name__)
self.storage_client = storage_client or storage.Client(project=self.project_id)
self.bucket_name = bucket_name or f"{self.project_id}-logs-data"
self._ensure_bucket_exists()
self.bucket = self.storage_client.bucket(self.bucket_name)
-
def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
"""
Export the spans to Google Cloud Logging and Cloud Trace.
@@ -67,13 +69,13 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
"""
for span in spans:
span_context = span.get_span_context()
- trace_id = format(span_context.trace_id, 'x')
- span_id = format(span_context.span_id, 'x')
+ trace_id = format(span_context.trace_id, "x")
+ span_id = format(span_context.span_id, "x")
span_dict = json.loads(span.to_json())
span_dict["trace"] = f"projects/{self.project_id}/traces/{trace_id}"
span_dict["span_id"] = span_id
-
+
span_dict = self._process_large_attributes(
span_dict=span_dict, span_id=span_id
)
@@ -93,7 +95,6 @@ def _ensure_bucket_exists(self) -> None:
logging.info(f"Bucket {self.bucket_name} not detected. Creating it now.")
self.storage_client.create_bucket(self.bucket_name)
-
def store_in_gcs(self, content: str, span_id: str) -> str:
"""
Initiate storing large content in Google Cloud Storage/
@@ -105,15 +106,14 @@ def store_in_gcs(self, content: str, span_id: str) -> str:
blob_name = f"spans/{span_id}.json"
blob = self.bucket.blob(blob_name)
- blob.upload_from_string(content, 'application/json')
+ blob.upload_from_string(content, "application/json")
return f"gs://{self.bucket_name}/{blob_name}"
-
def _process_large_attributes(self, span_dict: dict, span_id: str) -> dict:
"""
- Process large attribute values by storing them in GCS if they exceed the size
+ Process large attribute values by storing them in GCS if they exceed the size
limit of Google Cloud Logging.
-
+
:param span_dict: The span data dictionary
:param trace_id: The trace ID
:param span_id: The span ID
@@ -123,14 +123,16 @@ def _process_large_attributes(self, span_dict: dict, span_id: str) -> dict:
if len(json.dumps(attributes).encode()) > 255 * 1024: # 250 KB
# Separate large payload from other attributes
attributes_payload = {
- k: v for k, v in attributes.items()
+ k: v
+ for k, v in attributes.items()
if "traceloop.association.properties" not in k
}
attributes_retain = {
- k: v for k, v in attributes.items()
+ k: v
+ for k, v in attributes.items()
if "traceloop.association.properties" in k
}
-
+
# Store large payload in GCS
gcs_uri = self.store_in_gcs(json.dumps(attributes_payload), span_id)
attributes_retain["uri_payload"] = gcs_uri
@@ -138,11 +140,11 @@ def _process_large_attributes(self, span_dict: dict, span_id: str) -> dict:
f"https://storage.mtls.cloud.google.com/"
f"{self.bucket_name}/spans/{span_id}.json"
)
-
+
span_dict["attributes"] = attributes_retain
logging.info(
"Length of payload span above 250 KB, storing attributes in GCS "
"to avoid large log entry errors"
)
-
- return span_dict
\ No newline at end of file
+
+ return span_dict
diff --git a/gemini/sample-apps/conversational-genai-app-template/notebooks/getting_started.ipynb b/gemini/sample-apps/conversational-genai-app-template/notebooks/getting_started.ipynb
index a1639a21469..9289f4dbf56 100644
--- a/gemini/sample-apps/conversational-genai-app-template/notebooks/getting_started.ipynb
+++ b/gemini/sample-apps/conversational-genai-app-template/notebooks/getting_started.ipynb
@@ -236,6 +236,7 @@
"outputs": [],
"source": [
"import sys\n",
+ "\n",
"sys.path.append(\"../\")"
]
},
@@ -330,7 +331,13 @@
"metadata": {},
"outputs": [],
"source": [
- "SUPPORTED_EVENTS = [\"on_tool_start\", \"on_tool_end\",\"on_retriever_start\", \"on_retriever_end\", \"on_chat_model_stream\"]"
+ "SUPPORTED_EVENTS = [\n",
+ " \"on_tool_start\",\n",
+ " \"on_tool_end\",\n",
+ " \"on_retriever_start\",\n",
+ " \"on_retriever_end\",\n",
+ " \"on_chat_model_stream\",\n",
+ "]"
]
},
{
@@ -347,7 +354,7 @@
"metadata": {},
"outputs": [],
"source": [
- "llm = ChatVertexAI(model_name=\"gemini-1.5-flash-001\", temperature=0)\n"
+ "llm = ChatVertexAI(model_name=\"gemini-1.5-flash-001\", temperature=0)"
]
},
{
@@ -435,22 +442,26 @@
" return \"It's 60 degrees and foggy.\"\n",
" return \"It's 90 degrees and sunny.\"\n",
"\n",
+ "\n",
"tools = [search]\n",
"\n",
"# 2. Set up the language model\n",
"llm = llm.bind_tools(tools)\n",
"\n",
+ "\n",
"# 3. Define workflow components\n",
"def should_continue(state: MessagesState) -> Literal[\"tools\", END]:\n",
" \"\"\"Determines whether to use tools or end the conversation.\"\"\"\n",
- " last_message = state['messages'][-1]\n",
+ " last_message = state[\"messages\"][-1]\n",
" return \"tools\" if last_message.tool_calls else END\n",
"\n",
+ "\n",
"async def call_model(state: MessagesState, config: RunnableConfig):\n",
" \"\"\"Calls the language model and returns the response.\"\"\"\n",
- " response = llm.invoke(state['messages'], config)\n",
+ " response = llm.invoke(state[\"messages\"], config)\n",
" return {\"messages\": response}\n",
"\n",
+ "\n",
"# 4. Create the workflow graph\n",
"workflow = StateGraph(MessagesState)\n",
"workflow.add_node(\"agent\", call_model)\n",
@@ -459,7 +470,7 @@
"\n",
"# 5. Define graph edges\n",
"workflow.add_conditional_edges(\"agent\", should_continue)\n",
- "workflow.add_edge(\"tools\", 'agent')\n",
+ "workflow.add_edge(\"tools\", \"agent\")\n",
"\n",
"# 6. Compile the workflow\n",
"chain = workflow.compile()"
@@ -529,7 +540,9 @@
"\n",
"\n",
"@custom_chain\n",
- "def chain(input: Dict[str, Any], **kwargs) -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]:\n",
+ "def chain(\n",
+ " input: Dict[str, Any], **kwargs\n",
+ ") -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]:\n",
" \"\"\"\n",
" Implements a RAG QA chain. Decorated with `custom_chain` to offer Langchain compatible astream_events\n",
" and invoke interface and OpenTelemetry tracing.\n",
@@ -546,7 +559,7 @@
"\n",
" # Stream LLM response\n",
" for chunk in response_chain.stream(\n",
- " input={\"messages\": input[\"messages\"], \"relevant_documents\": ranked_docs}\n",
+ " input={\"messages\": input[\"messages\"], \"relevant_documents\": ranked_docs}\n",
" ):\n",
" yield OnChatModelStreamEvent(data={\"chunk\": chunk})"
]
@@ -703,7 +716,7 @@
"source": [
"scored_data[\"user\"] = scored_data[\"human_message\"].apply(lambda x: x[\"content\"])\n",
"scored_data[\"reference\"] = scored_data[\"ai_message\"].apply(lambda x: x[\"content\"])\n",
- "scored_data\n"
+ "scored_data"
]
},
{
@@ -771,7 +784,7 @@
"metadata": {},
"outputs": [],
"source": [
- "experiment_name = \"rapid-eval-langchain-eval\" # @param {type:\"string\"}\n"
+ "experiment_name = \"rapid-eval-langchain-eval\" # @param {type:\"string\"}"
]
},
{
@@ -791,10 +804,15 @@
"metadata": {},
"outputs": [],
"source": [
- "metrics = [\"fluency\", \"safety\", custom_faithfulness_metric]\n",
+ "metrics = [\"fluency\", \"safety\", custom_faithfulness_metric]\n",
"\n",
"metrics = [custom_faithfulness_metric]\n",
- "eval_task = EvalTask(dataset=scored_data, metrics=metrics, experiment=experiment_name, metric_column_mapping={\"user\":\"prompt\"} )\n",
+ "eval_task = EvalTask(\n",
+ " dataset=scored_data,\n",
+ " metrics=metrics,\n",
+ " experiment=experiment_name,\n",
+ " metric_column_mapping={\"user\": \"prompt\"},\n",
+ ")\n",
"eval_result = eval_task.evaluate()"
]
},
diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/side_bar.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/side_bar.py
index 9c61d3cddfa..90b8b2b0bae 100644
--- a/gemini/sample-apps/conversational-genai-app-template/streamlit/side_bar.py
+++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/side_bar.py
@@ -28,7 +28,6 @@
class SideBar:
-
def __init__(self, st) -> None:
self.st = st
@@ -36,28 +35,36 @@ def init_side_bar(self):
with self.st.sidebar:
self.url_input_field = self.st.text_input(
label="Service URL",
- value=os.environ.get("SERVICE_URL", DEFAULT_BASE_URL)
+ value=os.environ.get("SERVICE_URL", DEFAULT_BASE_URL),
)
self.should_authenticate_request = self.st.checkbox(
label="Authenticate request",
value=False,
- help="If checked, any request to the server will contain an"
+ help="If checked, any request to the server will contain an"
"Identity token to allow authentication. "
"See the Cloud Run documentation to know more about authentication:"
- "https://cloud.google.com/run/docs/authenticating/service-to-service"
+ "https://cloud.google.com/run/docs/authenticating/service-to-service",
)
col1, col2, col3 = self.st.columns(3)
with col1:
if self.st.button("+ New chat"):
- if len(self.st.session_state.user_chats[self.st.session_state['session_id']][
- "messages"]) > 0:
+ if (
+ len(
+ self.st.session_state.user_chats[
+ self.st.session_state["session_id"]
+ ]["messages"]
+ )
+ > 0
+ ):
self.st.session_state.run_id = None
- self.st.session_state['session_id'] = str(uuid.uuid4())
+ self.st.session_state["session_id"] = str(uuid.uuid4())
self.st.session_state.session_db.get_session(
- session_id=self.st.session_state['session_id'],
+ session_id=self.st.session_state["session_id"],
)
- self.st.session_state.user_chats[self.st.session_state['session_id']] = {
+ self.st.session_state.user_chats[
+ self.st.session_state["session_id"]
+ ] = {
"title": EMPTY_CHAT_NAME,
"messages": [],
}
@@ -66,16 +73,20 @@ def init_side_bar(self):
if self.st.button("Delete chat"):
self.st.session_state.run_id = None
self.st.session_state.session_db.clear()
- self.st.session_state.user_chats.pop(self.st.session_state['session_id'])
+ self.st.session_state.user_chats.pop(
+ self.st.session_state["session_id"]
+ )
if len(self.st.session_state.user_chats) > 0:
chat_id = list(self.st.session_state.user_chats.keys())[0]
- self.st.session_state['session_id'] = chat_id
+ self.st.session_state["session_id"] = chat_id
self.st.session_state.session_db.get_session(
- session_id=self.st.session_state['session_id'],
+ session_id=self.st.session_state["session_id"],
)
else:
- self.st.session_state['session_id'] = str(uuid.uuid4())
- self.st.session_state.user_chats[self.st.session_state['session_id']] = {
+ self.st.session_state["session_id"] = str(uuid.uuid4())
+ self.st.session_state.user_chats[
+ self.st.session_state["session_id"]
+ ] = {
"title": EMPTY_CHAT_NAME,
"messages": [],
}
@@ -84,47 +95,54 @@ def init_side_bar(self):
save_chat(self.st)
self.st.subheader("Recent") # Style the heading
-
+
all_chats = list(reversed(self.st.session_state.user_chats.items()))
for chat_id, chat in all_chats[:NUM_CHAT_IN_RECENT]:
if self.st.button(chat["title"], key=chat_id):
self.st.session_state.run_id = None
- self.st.session_state['session_id'] = chat_id
+ self.st.session_state["session_id"] = chat_id
self.st.session_state.session_db.get_session(
- session_id=self.st.session_state['session_id'],
+ session_id=self.st.session_state["session_id"],
)
with self.st.expander("Other chats"):
for chat_id, chat in all_chats[NUM_CHAT_IN_RECENT:]:
if self.st.button(chat["title"], key=chat_id):
self.st.session_state.run_id = None
- self.st.session_state['session_id'] = chat_id
+ self.st.session_state["session_id"] = chat_id
self.st.session_state.session_db.get_session(
- session_id=self.st.session_state['session_id'],
+ session_id=self.st.session_state["session_id"],
)
self.st.divider()
self.st.header("Upload files from local")
bucket_name = self.st.text_input(
label="GCS Bucket for upload",
- value=os.environ.get("BUCKET_NAME","gs://your-bucket-name")
+ value=os.environ.get("BUCKET_NAME", "gs://your-bucket-name"),
)
- if 'checkbox_state' not in self.st.session_state:
+ if "checkbox_state" not in self.st.session_state:
self.st.session_state.checkbox_state = True
-
+
self.st.session_state.checkbox_state = self.st.checkbox(
- "Upload to GCS first (suggested)",
- value=True,
- help=HELP_GCS_CHECKBOX
+ "Upload to GCS first (suggested)", value=True, help=HELP_GCS_CHECKBOX
)
self.uploaded_files = self.st.file_uploader(
- label="Send files from local", accept_multiple_files=True,
+ label="Send files from local",
+ accept_multiple_files=True,
key=f"uploader_images_{self.st.session_state.uploader_key}",
type=[
- "png", "jpg", "jpeg", "txt", "docx",
- "pdf", "rtf", "csv", "tsv", "xlsx"
- ]
+ "png",
+ "jpg",
+ "jpeg",
+ "txt",
+ "docx",
+ "pdf",
+ "rtf",
+ "csv",
+ "tsv",
+ "xlsx",
+ ],
)
if self.uploaded_files and self.st.session_state.checkbox_state:
upload_files_to_gcs(self.st, bucket_name, self.uploaded_files)
@@ -136,7 +154,7 @@ def init_side_bar(self):
"GCS uris (comma-separated)",
value=self.st.session_state["gcs_uris_to_be_sent"],
key=f"upload_text_area_{self.st.session_state.uploader_key}",
- help=HELP_MESSAGE_MULTIMODALITY
+ help=HELP_MESSAGE_MULTIMODALITY,
)
-
- self.st.caption(f"Note: {HELP_MESSAGE_MULTIMODALITY}")
\ No newline at end of file
+
+ self.st.caption(f"Note: {HELP_MESSAGE_MULTIMODALITY}")
diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/streamlit_app.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/streamlit_app.py
index 11cad777de2..12947fac963 100644
--- a/gemini/sample-apps/conversational-genai-app-template/streamlit/streamlit_app.py
+++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/streamlit_app.py
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from functools import partial
import json
import uuid
-from functools import partial
from langchain_core.messages import HumanMessage
from side_bar import SideBar
+import streamlit as st
from streamlit_feedback import streamlit_feedback
from style.app_markdown import markdown_str
from utils.local_chat_history import LocalChatMessageHistory
@@ -25,8 +26,6 @@
from utils.multimodal_utils import format_content, get_parts_from_files
from utils.stream_handler import Client, StreamHandler, get_chain_response
-import streamlit as st
-
USER = "my_user"
EMPTY_CHAT_NAME = "Empty chat"
@@ -34,14 +33,14 @@
page_title="Playground",
layout="wide",
initial_sidebar_state="auto",
- menu_items=None
+ menu_items=None,
)
st.title("Playground")
st.markdown(markdown_str, unsafe_allow_html=True)
# First time Init of session variables
if "user_chats" not in st.session_state:
- st.session_state['session_id'] = str(uuid.uuid4())
+ st.session_state["session_id"] = str(uuid.uuid4())
st.session_state.uploader_key = 0
st.session_state.run_id = None
st.session_state.user_id = USER
@@ -49,37 +48,41 @@
st.session_state["gcs_uris_to_be_sent"] = ""
st.session_state.modified_prompt = None
st.session_state.session_db = LocalChatMessageHistory(
- session_id=st.session_state['session_id'],
- user_id=st.session_state['user_id'],
+ session_id=st.session_state["session_id"],
+ user_id=st.session_state["user_id"],
)
st.session_state.user_chats = st.session_state.session_db.get_all_conversations()
- st.session_state.user_chats[st.session_state['session_id']] = {
+ st.session_state.user_chats[st.session_state["session_id"]] = {
"title": EMPTY_CHAT_NAME,
"messages": [],
- }
+ }
side_bar = SideBar(st=st)
side_bar.init_side_bar()
-client = Client(url=side_bar.url_input_field, authenticate_request=side_bar.should_authenticate_request)
+client = Client(
+ url=side_bar.url_input_field,
+ authenticate_request=side_bar.should_authenticate_request,
+)
# Write all messages of current conversation
-messages = st.session_state.user_chats[st.session_state['session_id']]["messages"]
+messages = st.session_state.user_chats[st.session_state["session_id"]]["messages"]
for i, message in enumerate(messages):
with st.chat_message(message["type"]):
if message["type"] == "ai":
-
if message.get("tool_calls") and len(message.get("tool_calls")) > 0:
- tool_expander = st.expander(
- label="Tool Calls:",
- expanded=False)
+ tool_expander = st.expander(label="Tool Calls:", expanded=False)
with tool_expander:
for index, tool_call in enumerate(message["tool_calls"]):
# ruff: noqa: E501
- tool_call_output = message["additional_kwargs"]["tool_calls_outputs"][index]
- msg = f"\n\nEnding tool: `{tool_call['name']}` with\n **args:**\n" \
- f"```\n{json.dumps(tool_call['args'], indent=2)}\n```\n" \
- f"\n\n**output:**\n " \
- f"```\n{json.dumps(tool_call_output['output'], indent=2)}\n```"
+ tool_call_output = message["additional_kwargs"][
+ "tool_calls_outputs"
+ ][index]
+ msg = (
+ f"\n\nEnding tool: `{tool_call['name']}` with\n **args:**\n"
+ f"```\n{json.dumps(tool_call['args'], indent=2)}\n```\n"
+ f"\n\n**output:**\n "
+ f"```\n{json.dumps(tool_call_output['output'], indent=2)}\n```"
+ )
st.markdown(msg, unsafe_allow_html=True)
st.markdown(format_content(message["content"]), unsafe_allow_html=True)
@@ -88,29 +91,25 @@
refresh_button = f"{i}_refresh"
delete_button = f"{i}_delete"
content = message["content"]
-
- if isinstance(message["content"],list):
+
+ if isinstance(message["content"], list):
content = message["content"][-1]["text"]
with col1:
- st.button(
- label="✎",
- key=edit_button,
- type="primary"
- )
+ st.button(label="✎", key=edit_button, type="primary")
if message["type"] == "human":
with col2:
st.button(
label="⟳",
key=refresh_button,
type="primary",
- on_click=partial(MessageEditing.refresh_message, st, i, content)
+ on_click=partial(MessageEditing.refresh_message, st, i, content),
)
with col3:
st.button(
label="X",
key=delete_button,
type="primary",
- on_click=partial(MessageEditing.delete_message, st, i)
+ on_click=partial(MessageEditing.delete_message, st, i),
)
if st.session_state[edit_button]:
@@ -118,55 +117,52 @@
"Edit your message:",
value=content,
key=f"edit_box_{i}",
- on_change=partial(MessageEditing.edit_message, st, i, message["type"]))
+ on_change=partial(MessageEditing.edit_message, st, i, message["type"]),
+ )
# Handle new (or modified) user prompt and response
prompt = st.chat_input()
if prompt is None:
prompt = st.session_state.modified_prompt
-
+
if prompt:
st.session_state.modified_prompt = None
parts = get_parts_from_files(
upload_gcs_checkbox=st.session_state.checkbox_state,
- uploaded_files=side_bar.uploaded_files,
- gcs_uris=side_bar.gcs_uris
+ uploaded_files=side_bar.uploaded_files,
+ gcs_uris=side_bar.gcs_uris,
)
st.session_state["gcs_uris_to_be_sent"] = ""
- parts.append(
- {
- "type": "text",
- "text": prompt
- }
+ parts.append({"type": "text", "text": prompt})
+ st.session_state.user_chats[st.session_state["session_id"]]["messages"].append(
+ HumanMessage(content=parts).model_dump()
)
- st.session_state.user_chats[st.session_state['session_id']]["messages"].append(
- HumanMessage(content=parts).model_dump())
human_message = st.chat_message("human")
with human_message:
existing_user_input = format_content(parts)
user_input = st.markdown(existing_user_input, unsafe_allow_html=True)
-
+
ai_message = st.chat_message("ai")
with ai_message:
status = st.status("Generating answer🤖")
stream_handler = StreamHandler(st=st)
- get_chain_response(
- st=st,
- client=client,
- stream_handler=stream_handler
- )
+ get_chain_response(st=st, client=client, stream_handler=stream_handler)
status.update(label="Finished!", state="complete", expanded=False)
- if st.session_state.user_chats[st.session_state['session_id']][
- "title"] == EMPTY_CHAT_NAME:
+ if (
+ st.session_state.user_chats[st.session_state["session_id"]]["title"]
+ == EMPTY_CHAT_NAME
+ ):
st.session_state.session_db.set_title(
- st.session_state.user_chats[st.session_state['session_id']]
+ st.session_state.user_chats[st.session_state["session_id"]]
)
- st.session_state.session_db.upsert_session(st.session_state.user_chats[st.session_state['session_id']])
+ st.session_state.session_db.upsert_session(
+ st.session_state.user_chats[st.session_state["session_id"]]
+ )
if len(parts) > 1:
st.session_state.uploader_key += 1
st.rerun()
@@ -175,12 +171,10 @@
feedback = streamlit_feedback(
feedback_type="faces",
optional_text_label="[Optional] Please provide an explanation",
- key=f"feedback-{st.session_state.run_id}"
+ key=f"feedback-{st.session_state.run_id}",
)
if feedback is not None:
client.log_feedback(
feedback_dict=feedback,
run_id=st.session_state.run_id,
)
-
-
diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/style/app_markdown.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/style/app_markdown.py
index b91a810e9c3..b1332580450 100644
--- a/gemini/sample-apps/conversational-genai-app-template/streamlit/style/app_markdown.py
+++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/style/app_markdown.py
@@ -34,4 +34,4 @@
color: !important;
}
-"""
\ No newline at end of file
+"""
diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/local_chat_history.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/local_chat_history.py
index 6e3967ccd39..ebbe0110ed7 100644
--- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/local_chat_history.py
+++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/local_chat_history.py
@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
from datetime import datetime
+import os
-import yaml
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import HumanMessage
from utils.title_summary import chain_title
+import yaml
class LocalChatMessageHistory(BaseChatMessageHistory):
def __init__(
- self,
- user_id: str,
- session_id: str = "default",
- base_dir: str = ".streamlit_chats"
+ self,
+ user_id: str,
+ session_id: str = "default",
+ base_dir: str = ".streamlit_chats",
) -> None:
self.user_id = user_id
self.session_id = session_id
@@ -43,12 +43,13 @@ def get_session(self, session_id):
def get_all_conversations(self):
conversations = {}
for filename in os.listdir(self.user_dir):
- if filename.endswith('.yaml'):
+ if filename.endswith(".yaml"):
file_path = os.path.join(self.user_dir, filename)
- with open(file_path, 'r') as f:
+ with open(file_path, "r") as f:
conversation = yaml.safe_load(f)
if not isinstance(conversation, list) or len(conversation) > 1:
- raise ValueError(f"""Invalid format in {file_path}.
+ raise ValueError(
+ f"""Invalid format in {file_path}.
YAML file can only contain one conversation with the following
structure.
- messages:
@@ -61,17 +62,18 @@ def get_all_conversations(self):
conversation["title"] = filename
conversations[filename[:-5]] = conversation
return dict(
- sorted(conversations.items(), key=lambda x: x[1].get('update_time', '')))
+ sorted(conversations.items(), key=lambda x: x[1].get("update_time", ""))
+ )
def upsert_session(self, session) -> None:
- session['update_time'] = datetime.now().isoformat()
- with open(self.session_file, 'w') as f:
+ session["update_time"] = datetime.now().isoformat()
+ with open(self.session_file, "w") as f:
yaml.dump(
[session],
f,
allow_unicode=True,
default_flow_style=False,
- encoding='utf-8'
+ encoding="utf-8",
)
def set_title(self, session) -> None:
@@ -100,4 +102,4 @@ def set_title(self, session) -> None:
def clear(self) -> None:
if os.path.exists(self.session_file):
- os.remove(self.session_file)
\ No newline at end of file
+ os.remove(self.session_file)
diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/message_editing.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/message_editing.py
index e696714657d..7b829d67d26 100644
--- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/message_editing.py
+++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/message_editing.py
@@ -12,27 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-class MessageEditing:
+class MessageEditing:
@staticmethod
def edit_message(st, button_idx, message_type):
button_id = f"edit_box_{button_idx}"
if message_type == "human":
- messages = st.session_state.user_chats[st.session_state['session_id']]["messages"]
- st.session_state.user_chats[st.session_state['session_id']][
- "messages"] = messages[:button_idx]
+ messages = st.session_state.user_chats[st.session_state["session_id"]][
+ "messages"
+ ]
+ st.session_state.user_chats[st.session_state["session_id"]][
+ "messages"
+ ] = messages[:button_idx]
st.session_state.modified_prompt = st.session_state[button_id]
else:
- st.session_state.user_chats[st.session_state['session_id']]["messages"][
- button_idx]["content"] = st.session_state[button_id]
-
+ st.session_state.user_chats[st.session_state["session_id"]]["messages"][
+ button_idx
+ ]["content"] = st.session_state[button_id]
+
@staticmethod
def refresh_message(st, button_idx, content):
- messages = st.session_state.user_chats[st.session_state['session_id']]["messages"]
- st.session_state.user_chats[st.session_state['session_id']]["messages"] = messages[:button_idx]
+ messages = st.session_state.user_chats[st.session_state["session_id"]][
+ "messages"
+ ]
+ st.session_state.user_chats[st.session_state["session_id"]][
+ "messages"
+ ] = messages[:button_idx]
st.session_state.modified_prompt = content
@staticmethod
def delete_message(st, button_idx):
- messages = st.session_state.user_chats[st.session_state['session_id']]["messages"]
- st.session_state.user_chats[st.session_state['session_id']]["messages"] = messages[:button_idx]
\ No newline at end of file
+ messages = st.session_state.user_chats[st.session_state["session_id"]][
+ "messages"
+ ]
+ st.session_state.user_chats[st.session_state["session_id"]][
+ "messages"
+ ] = messages[:button_idx]
diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/multimodal_utils.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/multimodal_utils.py
index 94d35231058..4c1c8a49738 100644
--- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/multimodal_utils.py
+++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/multimodal_utils.py
@@ -17,12 +17,16 @@
from google.cloud import storage
-HELP_MESSAGE_MULTIMODALITY = "To ensure Gemini models can access the URIs you " \
- "provide, store all URIs in buckets within the same " \
- "GCP Project that Gemini uses."
+HELP_MESSAGE_MULTIMODALITY = (
+ "To ensure Gemini models can access the URIs you "
+ "provide, store all URIs in buckets within the same "
+ "GCP Project that Gemini uses."
+)
-HELP_GCS_CHECKBOX = "Enabling GCS upload will increase app performance by avoiding to" \
- " pass large byte strings to the model"
+HELP_GCS_CHECKBOX = (
+ "Enabling GCS upload will increase app performance by avoiding to"
+ " pass large byte strings to the model"
+)
def format_content(content):
@@ -41,9 +45,12 @@ def format_content(content):
if part["type"] == "image_url":
image_url = part["image_url"]["url"]
image_markdown = f''
- markdown = markdown + f"""
+ markdown = (
+ markdown
+ + f"""
- {image_markdown}
"""
+ )
if part["type"] == "media":
# Local other media
if "data" in part:
@@ -54,19 +61,27 @@ def format_content(content):
if "image" in part["mime_type"]:
image_url = gs_uri_to_https_url(part["file_uri"])
image_markdown = f''
- markdown = markdown + f"""
+ markdown = (
+ markdown
+ + f"""
- {image_markdown}
"""
+ )
# GCS other media
else:
- image_url = gs_uri_to_https_url(part['file_uri'])
- markdown = markdown + f"- Remote media: " \
- f"[{part['file_uri']}]({image_url})\n"
- markdown = markdown + f"""
+ image_url = gs_uri_to_https_url(part["file_uri"])
+ markdown = (
+ markdown + f"- Remote media: "
+ f"[{part['file_uri']}]({image_url})\n"
+ )
+ markdown = (
+ markdown
+ + f"""
{text}"""
+ )
return markdown
-
+
def get_gcs_blob_mime_type(gcs_uri):
"""Fetches the MIME type (content type) of a Google Cloud Storage blob.
@@ -103,17 +118,17 @@ def get_parts_from_files(upload_gcs_checkbox, uploaded_files, gcs_uris):
content = {
"type": "image_url",
"image_url": {
- "url": f"data:{uploaded_file.type};base64," \
- f"{base64.b64encode(im_bytes).decode('utf-8')}"
+ "url": f"data:{uploaded_file.type};base64,"
+ f"{base64.b64encode(im_bytes).decode('utf-8')}"
},
"file_name": uploaded_file.name,
}
else:
content = {
"type": "media",
- "data": base64.b64encode(im_bytes).decode('utf-8'),
+ "data": base64.b64encode(im_bytes).decode("utf-8"),
"file_name": uploaded_file.name,
- "mime_type": uploaded_file.type
+ "mime_type": uploaded_file.type,
}
parts.append(content)
@@ -122,11 +137,12 @@ def get_parts_from_files(upload_gcs_checkbox, uploaded_files, gcs_uris):
content = {
"type": "media",
"file_uri": uri,
- "mime_type": get_gcs_blob_mime_type(uri)
+ "mime_type": get_gcs_blob_mime_type(uri),
}
parts.append(content)
return parts
+
def upload_bytes_to_gcs(bucket_name, blob_name, file_bytes, content_type=None):
"""Uploads a bytes object to Google Cloud Storage and returns the GCS URI.
@@ -177,17 +193,17 @@ def gs_uri_to_https_url(gs_uri):
def upload_files_to_gcs(st, bucket_name, files_to_upload):
- bucket_name = bucket_name.replace("gs://", "")
- uploaded_uris = []
- for file in files_to_upload:
- if file:
- file_bytes = file.read()
- gcs_uri = upload_bytes_to_gcs(
- bucket_name=bucket_name,
- blob_name=file.name,
- file_bytes=file_bytes,
- content_type=file.type
- )
- uploaded_uris.append(gcs_uri)
- st.session_state.uploader_key += 1
- st.session_state["gcs_uris_to_be_sent"] = ",".join(uploaded_uris)
+ bucket_name = bucket_name.replace("gs://", "")
+ uploaded_uris = []
+ for file in files_to_upload:
+ if file:
+ file_bytes = file.read()
+ gcs_uri = upload_bytes_to_gcs(
+ bucket_name=bucket_name,
+ blob_name=file.name,
+ file_bytes=file_bytes,
+ content_type=file.type,
+ )
+ uploaded_uris.append(gcs_uri)
+ st.session_state.uploader_key += 1
+ st.session_state["gcs_uris_to_be_sent"] = ",".join(uploaded_uris)
diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/stream_handler.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/stream_handler.py
index 9c560687152..bfe0e07c12d 100644
--- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/stream_handler.py
+++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/stream_handler.py
@@ -17,12 +17,11 @@
from urllib.parse import urljoin
import google.auth
+from google.auth.exceptions import DefaultCredentialsError
import google.auth.transport.requests
import google.oauth2.id_token
-import requests
-from google.auth.exceptions import DefaultCredentialsError
from langchain_core.messages import AIMessage
-
+import requests
import streamlit as st
@@ -78,25 +77,24 @@ def log_feedback(self, feedback_dict, run_id):
"Content-Type": "application/json",
}
if self.authenticate_request:
- headers["Authorization"] = f"Bearer {self.id_token}"
+ headers["Authorization"] = f"Bearer {self.id_token}"
requests.post(url, data=json.dumps(feedback_dict), headers=headers)
- def stream_events(self, data: Dict[str, Any]) -> Generator[
- Dict[str, Any], None, None]:
+ def stream_events(
+ self, data: Dict[str, Any]
+ ) -> Generator[Dict[str, Any], None, None]:
"""Stream events from the server, yielding parsed event data."""
- headers = {
- "Content-Type": "application/json",
- "Accept": "text/event-stream"
- }
+ headers = {"Content-Type": "application/json", "Accept": "text/event-stream"}
if self.authenticate_request:
- headers["Authorization"] = f"Bearer {self.id_token}"
+ headers["Authorization"] = f"Bearer {self.id_token}"
- with requests.post(self.url, json={"input": data}, headers=headers,
- stream=True) as response:
+ with requests.post(
+ self.url, json={"input": data}, headers=headers, stream=True
+ ) as response:
for line in response.iter_lines():
if line:
try:
- event = json.loads(line.decode('utf-8'))
+ event = json.loads(line.decode("utf-8"))
# print(event)
yield event
except json.JSONDecodeError:
@@ -125,9 +123,6 @@ def new_status(self, status_update: str) -> None:
self.tool_expander.markdown(status_update)
-
-
-
class EventProcessor:
"""Processes events from the stream and updates the UI accordingly."""
@@ -144,14 +139,14 @@ def __init__(self, st, client, stream_handler):
def process_events(self):
"""Process events from the stream, handling each event type appropriately."""
- messages = \
- self.st.session_state.user_chats[self.st.session_state['session_id']][
- "messages"]
+ messages = self.st.session_state.user_chats[
+ self.st.session_state["session_id"]
+ ]["messages"]
stream = self.client.stream_events(
data={
"messages": messages,
- "user_id": self.st.session_state['user_id'],
- "session_id": self.st.session_state['session_id'],
+ "user_id": self.st.session_state["user_id"],
+ "session_id": self.st.session_state["session_id"],
}
)
@@ -173,11 +168,13 @@ def process_events(self):
def handle_metadata(self, event: Dict[str, Any]) -> None:
"""Handle metadata events."""
- self.current_run_id = event['data'].get('run_id')
+ self.current_run_id = event["data"].get("run_id")
def handle_tool_start(self, event: Dict[str, Any]) -> None:
"""Handle the start of a tool or retriever execution."""
- msg = f"\n\nCalling tool: `{event['name']}` with args: `{event['data']['input']}`"
+ msg = (
+ f"\n\nCalling tool: `{event['name']}` with args: `{event['data']['input']}`"
+ )
self.stream_handler.new_status(msg)
def handle_tool_end(self, event: Dict[str, Any]) -> None:
@@ -185,24 +182,26 @@ def handle_tool_end(self, event: Dict[str, Any]) -> None:
data = event["data"]
# Support tool events
if isinstance(data["output"], dict):
- tool_id = data["output"].get('tool_call_id', None)
- tool_name = data["output"].get('name', 'Unknown Tool')
+ tool_id = data["output"].get("tool_call_id", None)
+ tool_name = data["output"].get("name", "Unknown Tool")
# Support retriever events
else:
tool_id = event.get("id", "None")
tool_name = event.get("name", event["event"])
- tool_input = data['input']
- tool_output = data['output']
+ tool_input = data["input"]
+ tool_output = data["output"]
tool_call = {"id": tool_id, "name": tool_name, "args": tool_input}
self.tool_calls.append(tool_call)
- tool_call_outputs = {"id":tool_id, "output":tool_output}
+ tool_call_outputs = {"id": tool_id, "output": tool_output}
self.tool_calls_outputs.append(tool_call_outputs)
- msg = f"\n\nEnding tool: `{tool_call['name']}` with\n **args:**\n" \
- f"```\n{json.dumps(tool_call['args'], indent=2)}\n```\n" \
- f"\n\n**output:**\n " \
- f"```\n{json.dumps(tool_output, indent=2)}\n```"
+ msg = (
+ f"\n\nEnding tool: `{tool_call['name']}` with\n **args:**\n"
+ f"```\n{json.dumps(tool_call['args'], indent=2)}\n```\n"
+ f"\n\n**output:**\n "
+ f"```\n{json.dumps(tool_output, indent=2)}\n```"
+ )
self.stream_handler.new_status(msg)
def handle_chat_model_stream(self, event: Dict[str, Any]) -> None:
@@ -211,7 +210,7 @@ def handle_chat_model_stream(self, event: Dict[str, Any]) -> None:
content = data["chunk"]["content"]
self.additional_kwargs = {
**self.additional_kwargs,
- **data["chunk"]["additional_kwargs"]
+ **data["chunk"]["additional_kwargs"],
}
if content and len(content.strip()) > 0:
self.final_content += content
@@ -225,12 +224,13 @@ def handle_end(self, event: Dict[str, Any]) -> None:
content=self.final_content,
tool_calls=self.tool_calls,
id=self.current_run_id,
- additional_kwargs=additional_kwargs
+ additional_kwargs=additional_kwargs,
).model_dump()
- session = self.st.session_state['session_id']
+ session = self.st.session_state["session_id"]
self.st.session_state.user_chats[session]["messages"].append(final_message)
self.st.session_state.run_id = self.current_run_id
+
def get_chain_response(st, client, stream_handler):
"""Process the chain response update the Streamlit UI.
diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/title_summary.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/title_summary.py
index a34efa51abd..443263c1a4f 100644
--- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/title_summary.py
+++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/title_summary.py
@@ -63,4 +63,4 @@
MessagesPlaceholder(variable_name="messages"),
])
-chain_title = title_template | llm
\ No newline at end of file
+chain_title = title_template | llm
diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/utils.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/utils.py
index 28f40eee878..42cc057a8c5 100644
--- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/utils.py
+++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/utils.py
@@ -20,8 +20,6 @@
SAVED_CHAT_PATH = str(os.getcwd()) + "/.saved_chats"
-
-
def preprocess_text(text):
if text[0] == "\n":
text = text[1:]
@@ -55,6 +53,6 @@ def save_chat(st):
file,
allow_unicode=True,
default_flow_style=False,
- encoding='utf-8',
+ encoding="utf-8",
)
st.toast(f"Chat saved to path: ↓ {Path(SAVED_CHAT_PATH) / filename}")
diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_langgraph_dummy_agent.py b/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_langgraph_dummy_agent.py
index 25615c9b1fd..0f94b40b7a0 100644
--- a/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_langgraph_dummy_agent.py
+++ b/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_langgraph_dummy_agent.py
@@ -14,10 +14,9 @@
import logging
-import pytest
-from langchain_core.messages import AIMessageChunk
-
from app.patterns.langgraph_dummy_agent.chain import chain
+from langchain_core.messages import AIMessageChunk
+import pytest
CHAIN_NAME = "Langgraph agent"
@@ -33,19 +32,21 @@ async def test_langgraph_chain_astream_events() -> None:
events = [event async for event in chain.astream_events(input_dict, version="v2")]
- assert len(events) > 1, f"Expected multiple events for {CHAIN_NAME} chain, " \
- f"got {len(events)}"
+ assert len(events) > 1, (
+ f"Expected multiple events for {CHAIN_NAME} chain, " f"got {len(events)}"
+ )
on_chain_stream_events = [
event for event in events if event["event"] == "on_chat_model_stream"
]
- assert on_chain_stream_events, f"Expected at least one on_chat_model_stream event" \
- f" for {CHAIN_NAME} chain"
+ assert on_chain_stream_events, (
+ f"Expected at least one on_chat_model_stream event" f" for {CHAIN_NAME} chain"
+ )
for event in on_chain_stream_events:
- assert AIMessageChunk.model_validate(event["data"]["chunk"]), (
- f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}"
- )
+ assert AIMessageChunk.model_validate(
+ event["data"]["chunk"]
+ ), f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}"
logging.info(f"All assertions passed for {CHAIN_NAME} chain")
diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_rag_qa.py b/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_rag_qa.py
index 951631b177b..36ace08010b 100644
--- a/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_rag_qa.py
+++ b/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_rag_qa.py
@@ -14,10 +14,9 @@
import logging
-import pytest
-from langchain_core.messages import AIMessageChunk
-
from app.patterns.custom_rag_qa.chain import chain
+from langchain_core.messages import AIMessageChunk
+import pytest
CHAIN_NAME = "Rag QA"
@@ -33,19 +32,21 @@ async def test_rag_chain_astream_events() -> None:
events = [event async for event in chain.astream_events(input_dict, version="v2")]
- assert len(events) > 1, f"Expected multiple events for {CHAIN_NAME} chain, " \
- f"got {len(events)}"
+ assert len(events) > 1, (
+ f"Expected multiple events for {CHAIN_NAME} chain, " f"got {len(events)}"
+ )
on_chain_stream_events = [
event for event in events if event["event"] == "on_chat_model_stream"
- ]
+ ]
- assert on_chain_stream_events, f"Expected at least one on_chat_model_stream event" \
- f" for {CHAIN_NAME} chain"
+ assert on_chain_stream_events, (
+ f"Expected at least one on_chat_model_stream event" f" for {CHAIN_NAME} chain"
+ )
for event in on_chain_stream_events:
- assert AIMessageChunk.model_validate(event["data"]["chunk"]), (
- f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}"
- )
+ assert AIMessageChunk.model_validate(
+ event["data"]["chunk"]
+ ), f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}"
logging.info(f"All assertions passed for {CHAIN_NAME} chain")
diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_chain.py b/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_chain.py
index 9b9da19f584..73c27f93105 100644
--- a/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_chain.py
+++ b/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_chain.py
@@ -14,10 +14,9 @@
import logging
-import pytest
-from langchain_core.messages import AIMessageChunk
-
from app.chain import chain
+from langchain_core.messages import AIMessageChunk
+import pytest
CHAIN_NAME = "Default"
@@ -33,19 +32,21 @@ async def test_default_chain_astream_events() -> None:
events = [event async for event in chain.astream_events(input_dict, version="v2")]
- assert len(events) > 1, f"Expected multiple events for {CHAIN_NAME} chain, " \
- f"got {len(events)}"
+ assert len(events) > 1, (
+ f"Expected multiple events for {CHAIN_NAME} chain, " f"got {len(events)}"
+ )
on_chain_stream_events = [
event for event in events if event["event"] == "on_chat_model_stream"
- ]
+ ]
- assert on_chain_stream_events, f"Expected at least one on_chat_model_stream event" \
- f" for {CHAIN_NAME} chain"
+ assert on_chain_stream_events, (
+ f"Expected at least one on_chat_model_stream event" f" for {CHAIN_NAME} chain"
+ )
for event in on_chain_stream_events:
- assert AIMessageChunk.model_validate(event["data"]["chunk"]), (
- f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}"
- )
+ assert AIMessageChunk.model_validate(
+ event["data"]["chunk"]
+ ), f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}"
logging.info(f"All assertions passed for {CHAIN_NAME} chain")
diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_server_e2e.py b/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_server_e2e.py
index c10e3c3e4d6..1c76307086f 100644
--- a/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_server_e2e.py
+++ b/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_server_e2e.py
@@ -18,8 +18,8 @@
import sys
import threading
import time
-import uuid
from typing import Any, Iterator
+import uuid
import pytest
import requests
@@ -35,11 +35,13 @@
HEADERS = {"Content-Type": "application/json"}
+
def log_output(pipe: Any, log_func: Any) -> None:
"""Log the output from the given pipe."""
- for line in iter(pipe.readline, ''):
+ for line in iter(pipe.readline, ""):
log_func(line.strip())
+
def start_server() -> subprocess.Popen[str]:
"""Start the FastAPI server using subprocess and log its output."""
command = [
@@ -58,18 +60,15 @@ def start_server() -> subprocess.Popen[str]:
# Start threads to log stdout and stderr in real-time
threading.Thread(
- target=log_output,
- args=(process.stdout, logger.info),
- daemon=True
+ target=log_output, args=(process.stdout, logger.info), daemon=True
).start()
threading.Thread(
- target=log_output,
- args=(process.stderr, logger.error),
- daemon=True
+ target=log_output, args=(process.stderr, logger.error), daemon=True
).start()
return process
+
def wait_for_server(timeout: int = 60, interval: int = 1) -> bool:
"""Wait for the server to be ready."""
start_time = time.time()
@@ -85,6 +84,7 @@ def wait_for_server(timeout: int = 60, interval: int = 1) -> bool:
logger.error(f"Server did not become ready within {timeout} seconds")
return False
+
@pytest.fixture(scope="session")
def server_fixture(request: Any) -> Iterator[subprocess.Popen[str]]:
"""Pytest fixture to start and stop the server for testing."""
@@ -103,6 +103,7 @@ def stop_server() -> None:
request.addfinalizer(stop_server)
yield server_process
+
def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None:
"""Test the chat stream functionality."""
logger.info("Starting chat stream test")
@@ -112,10 +113,10 @@ def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None:
"messages": [
{"role": "user", "content": "Hello, AI!"},
{"role": "ai", "content": "Hello!"},
- {"role": "user", "content": "What cooking recipes do you suggest?"}
+ {"role": "user", "content": "What cooking recipes do you suggest?"},
],
"user_id": "test-user",
- "session_id": "test-session"
+ "session_id": "test-session",
}
}
@@ -128,28 +129,35 @@ def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None:
logger.info(f"Received {len(events)} events")
assert len(events) > 2, f"Expected more than 2 events, got {len(events)}."
- assert events[0]["event"] == "metadata", f"First event should be 'metadata', " \
- f"got {events[0]['event']}"
+ assert events[0]["event"] == "metadata", (
+ f"First event should be 'metadata', " f"got {events[0]['event']}"
+ )
assert "run_id" in events[0]["data"], "Missing 'run_id' in metadata"
event_types = [event["event"] for event in events]
assert "on_chat_model_stream" in event_types, "Missing 'on_chat_model_stream' event"
- assert events[-1]["event"] == "end", f"Last event should be 'end', " \
- f"got {events[-1]['event']}"
+ assert events[-1]["event"] == "end", (
+ f"Last event should be 'end', " f"got {events[-1]['event']}"
+ )
logger.info("Test completed successfully")
+
def test_chat_stream_error_handling(server_fixture: subprocess.Popen[str]) -> None:
"""Test the chat stream error handling."""
logger.info("Starting chat stream error handling test")
data = {"input": [{"role": "invalid_role", "content": "Cause an error"}]}
- response = requests.post(STREAM_EVENTS_URL, headers=HEADERS, json=data, stream=True, timeout=10)
+ response = requests.post(
+ STREAM_EVENTS_URL, headers=HEADERS, json=data, stream=True, timeout=10
+ )
- assert response.status_code == 422, f"Expected status code 422, " \
- f"got {response.status_code}"
+ assert response.status_code == 422, (
+ f"Expected status code 422, " f"got {response.status_code}"
+ )
logger.info("Error handling test completed successfully")
+
def test_collect_feedback(server_fixture: subprocess.Popen[str]) -> None:
"""
Test the feedback collection endpoint (/feedback) to ensure it properly
diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/load_test/load_test.py b/gemini/sample-apps/conversational-genai-app-template/tests/load_test/load_test.py
index 0e356953c45..73af9fd5f44 100644
--- a/gemini/sample-apps/conversational-genai-app-template/tests/load_test/load_test.py
+++ b/gemini/sample-apps/conversational-genai-app-template/tests/load_test/load_test.py
@@ -26,17 +26,17 @@ class ChatStreamUser(HttpUser):
def chat_stream(self) -> None:
headers = {"Content-Type": "application/json"}
if os.environ.get("_ID_TOKEN"):
- headers['Authorization'] = f'Bearer {os.environ["_ID_TOKEN"]}'
+ headers["Authorization"] = f'Bearer {os.environ["_ID_TOKEN"]}'
data = {
"input": {
"messages": [
{"role": "user", "content": "Hello, AI!"},
{"role": "ai", "content": "Hello!"},
- {"role": "user", "content": "Who are you?"}
+ {"role": "user", "content": "Who are you?"},
],
"user_id": "test-user",
- "session_id": "test-session"
+ "session_id": "test-session",
}
}
@@ -48,7 +48,7 @@ def chat_stream(self) -> None:
json=data,
catch_response=True,
name="/stream_events first event",
- stream=True
+ stream=True,
) as response:
if response.status_code == 200:
events = []
@@ -73,9 +73,9 @@ def chat_stream(self) -> None:
response_time=total_time * 1000, # Convert to milliseconds
response_length=len(json.dumps(events)),
response=response,
- context={}
+ context={},
)
else:
response.failure("Unexpected response structure")
else:
- response.failure(f"Unexpected status code: {response.status_code}")
\ No newline at end of file
+ response.failure(f"Unexpected status code: {response.status_code}")
diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_server.py b/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_server.py
index f760cf500d6..51565628faf 100644
--- a/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_server.py
+++ b/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_server.py
@@ -17,13 +17,12 @@
from typing import Any
from unittest.mock import patch
-import pytest
+from app.server import app
+from app.utils.input_types import InputChat
from fastapi.testclient import TestClient
from httpx import AsyncClient
from langchain_core.messages import HumanMessage
-
-from app.server import app
-from app.utils.input_types import InputChat
+import pytest
# Set up logging
logging.basicConfig(level=logging.INFO)
@@ -41,7 +40,7 @@ def sample_input_chat() -> InputChat:
return InputChat(
user_id="test-user",
session_id="test-session",
- messages=[HumanMessage(content="What is the meaning of life?")]
+ messages=[HumanMessage(content="What is the meaning of life?")],
)
@@ -62,7 +61,7 @@ class AsyncIterator:
def __init__(self, seq: list) -> None:
self.iter = iter(seq)
- def __aiter__(self) -> 'AsyncIterator':
+ def __aiter__(self) -> "AsyncIterator":
return self
async def __anext__(self) -> Any:
@@ -77,7 +76,7 @@ def mock_chain() -> Any:
"""
Fixture to mock the chain object used in the application.
"""
- with patch('app.server.chain') as mock:
+ with patch("app.server.chain") as mock:
yield mock
@@ -94,8 +93,8 @@ async def test_stream_chat_events(mock_chain: Any) -> None:
"messages": [
{"role": "user", "content": "Hello, AI!"},
{"role": "ai", "content": "Hello!"},
- {"role": "user", "content": "What cooking recipes do you suggest?"}
- ]
+ {"role": "user", "content": "What cooking recipes do you suggest?"},
+ ],
}
}
@@ -107,8 +106,9 @@ async def test_stream_chat_events(mock_chain: Any) -> None:
mock_chain.astream_events.return_value = AsyncIterator(mock_events)
- with patch('uuid.uuid4', return_value=mock_uuid), \
- patch('app.server.Traceloop.set_association_properties'):
+ with patch("uuid.uuid4", return_value=mock_uuid), patch(
+ "app.server.Traceloop.set_association_properties"
+ ):
async with AsyncClient(app=app, base_url="http://test") as ac:
response = await ac.post("/stream_events", json=input_data)
diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_utils/test_tracing_exporter.py b/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_utils/test_tracing_exporter.py
index 73385724109..e584ee0b704 100644
--- a/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_utils/test_tracing_exporter.py
+++ b/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_utils/test_tracing_exporter.py
@@ -14,41 +14,47 @@
from unittest.mock import Mock, patch
-import pytest
+from app.utils.tracing import CloudTraceLoggingSpanExporter
from google.cloud import logging as gcp_logging
from google.cloud import storage
from opentelemetry.sdk.trace import ReadableSpan
-
-from app.utils.tracing import CloudTraceLoggingSpanExporter
+import pytest
@pytest.fixture
def mock_logging_client() -> Mock:
return Mock(spec=gcp_logging.Client)
+
@pytest.fixture
def mock_storage_client() -> Mock:
return Mock(spec=storage.Client)
+
@pytest.fixture
-def exporter(mock_logging_client: Mock, mock_storage_client: Mock) -> CloudTraceLoggingSpanExporter:
+def exporter(
+ mock_logging_client: Mock, mock_storage_client: Mock
+) -> CloudTraceLoggingSpanExporter:
return CloudTraceLoggingSpanExporter(
project_id="test-project",
logging_client=mock_logging_client,
storage_client=mock_storage_client,
- bucket_name="test-bucket"
+ bucket_name="test-bucket",
)
+
def test_init(exporter: CloudTraceLoggingSpanExporter) -> None:
assert exporter.project_id == "test-project"
assert exporter.bucket_name == "test-bucket"
assert exporter.debug is False
+
def test_ensure_bucket_exists(exporter: CloudTraceLoggingSpanExporter) -> None:
exporter.storage_client.bucket.return_value.exists.return_value = False
exporter._ensure_bucket_exists()
exporter.storage_client.create_bucket.assert_called_once_with("test-bucket")
+
def test_store_in_gcs(exporter: CloudTraceLoggingSpanExporter) -> None:
span_id = "test-span-id"
content = "test-content"
@@ -56,26 +62,26 @@ def test_store_in_gcs(exporter: CloudTraceLoggingSpanExporter) -> None:
assert uri == f"gs://test-bucket/spans/{span_id}.json"
exporter.bucket.blob.assert_called_once_with(f"spans/{span_id}.json")
-@patch('json.dumps')
+
+@patch("json.dumps")
def test_process_large_attributes_small_payload(
- mock_json_dumps: Mock,
- exporter: CloudTraceLoggingSpanExporter
+ mock_json_dumps: Mock, exporter: CloudTraceLoggingSpanExporter
) -> None:
- mock_json_dumps.return_value = 'a' * 100 # Small payload
+ mock_json_dumps.return_value = "a" * 100 # Small payload
span_dict = {"attributes": {"key": "value"}}
result = exporter._process_large_attributes(span_dict, "span-id")
assert result == span_dict
-@patch('json.dumps')
+
+@patch("json.dumps")
def test_process_large_attributes_large_payload(
- mock_json_dumps: Mock,
- exporter: CloudTraceLoggingSpanExporter
+ mock_json_dumps: Mock, exporter: CloudTraceLoggingSpanExporter
) -> None:
- mock_json_dumps.return_value = 'a' * (400 * 1024 + 1) # Large payload
+ mock_json_dumps.return_value = "a" * (400 * 1024 + 1) # Large payload
span_dict = {
"attributes": {
"key1": "value1",
- "traceloop.association.properties.key2": "value2"
+ "traceloop.association.properties.key2": "value2",
}
}
result = exporter._process_large_attributes(span_dict, "span-id")
@@ -84,16 +90,19 @@ def test_process_large_attributes_large_payload(
assert "key1" not in result["attributes"]
assert "traceloop.association.properties.key2" in result["attributes"]
-@patch.object(CloudTraceLoggingSpanExporter, '_process_large_attributes')
-def test_export(mock_process_large_attributes: Mock, exporter: CloudTraceLoggingSpanExporter) -> None:
+
+@patch.object(CloudTraceLoggingSpanExporter, "_process_large_attributes")
+def test_export(
+ mock_process_large_attributes: Mock, exporter: CloudTraceLoggingSpanExporter
+) -> None:
mock_span = Mock(spec=ReadableSpan)
mock_span.get_span_context.return_value.trace_id = 123
mock_span.get_span_context.return_value.span_id = 456
mock_span.to_json.return_value = '{"key": "value"}'
-
+
mock_process_large_attributes.return_value = {"processed": "data"}
-
+
exporter.export([mock_span])
-
+
mock_process_large_attributes.assert_called_once()
exporter.logger.log_struct.assert_called_once()