
Commit 64a9ba1
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] committed Sep 17, 2024
1 parent fff519d commit 64a9ba1
Showing 27 changed files with 487 additions and 384 deletions.
15 changes: 10 additions & 5 deletions gemini/sample-apps/conversational-genai-app-template/app/chain.py
@@ -24,16 +24,21 @@
}

llm = ChatVertexAI(
model_name="gemini-1.5-flash-001", temperature=0, max_output_tokens=1024,
safety_settings=safety_settings
model_name="gemini-1.5-flash-001",
temperature=0,
max_output_tokens=1024,
safety_settings=safety_settings,
)


template = ChatPromptTemplate.from_messages(
[
("system", """You are a conversational bot that produce recipes for users based
on a question."""),
MessagesPlaceholder(variable_name="messages")
(
"system",
"""You are a conversational bot that produce recipes for users based
on a question.""",
),
MessagesPlaceholder(variable_name="messages"),
]
)

gemini/sample-apps/conversational-genai-app-template/app/eval/utils.py

@@ -12,20 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+ from concurrent.futures import ThreadPoolExecutor
+ from functools import partial
import glob
import logging
import os
- from concurrent.futures import ThreadPoolExecutor
- from functools import partial
from typing import Any, Callable, Dict, Iterator, List

import nest_asyncio
import pandas as pd
- import yaml
from tqdm import tqdm
+ import yaml

nest_asyncio.apply()


def load_chats(path: str) -> List[Dict[str, Any]]:
"""
Loads a list of chats from a directory or file.
@@ -44,6 +45,7 @@ def load_chats(path: str) -> List[Dict[str, Any]]:
chats = chats + chats_in_file
return chats


def pairwise(iterable: List[Any]) -> Iterator[tuple[Any, Any]]:
"""Creates an iterable with tuples paired together
e.g s -> (s0, s1), (s2, s3), (s4, s5), ...
@@ -81,11 +83,9 @@ def generate_multiturn_history(df: pd.DataFrame) -> pd.DataFrame:
message = {
"human_message": human_message,
"ai_message": ai_message,
"conversation_history": conversation_history
"conversation_history": conversation_history,
}
- conversation_history = conversation_history + [
- human_message, ai_message
- ]
+ conversation_history = conversation_history + [human_message, ai_message]
processed_messages.append(message)

return pd.DataFrame(processed_messages)
@@ -103,7 +103,7 @@ def generate_message(row: tuple[int, Dict[str, Any]], callable: Any) -> Dict[str
Args:
row (tuple[int, Dict[str, Any]]): A tuple containing the index and a dictionary
with message data, including:
- "conversation_history" (List[str]): Optional. List of previous
- "conversation_history" (List[str]): Optional. List of previous
messages
in the conversation.
- "human_message" (str): The current human message.
@@ -118,7 +118,9 @@ def generate_message(row: tuple[int, Dict[str, Any]], callable: Any) -> Dict[str
- "response_obj" (Any): The usage metadata of the response from the callable.
"""
index, message = row
messages = message["conversation_history"] if "conversation_history" in message else []
messages = (
message["conversation_history"] if "conversation_history" in message else []
)
messages.append(message["human_message"])
input_callable = {"messages": messages}
response = callable.invoke(input_callable)
@@ -130,7 +132,7 @@ def generate_message(row: tuple[int, Dict[str, Any]], callable: Any) -> Dict[str
def batch_generate_messages(
messages: pd.DataFrame,
callable: Callable[[List[Dict[str, Any]]], Dict[str, Any]],
- max_workers: int = 4
+ max_workers: int = 4,
) -> pd.DataFrame:
"""Generates AI responses to user messages using a provided callable.
@@ -152,7 +154,7 @@ def batch_generate_messages(
]
```
- callable (Callable[[List[Dict[str, Any]]], Dict[str, Any]]): Callable object
+ callable (Callable[[List[Dict[str, Any]]], Dict[str, Any]]): Callable object
(e.g., Langchain Chain) used

Check warning on line 158 in gemini/sample-apps/conversational-genai-app-template/app/eval/utils.py (GitHub Actions / Check Spelling): `Langchain` matches a line_forbidden.patterns entry: `\b(?!LangChain\b)(?!langchain\b)[Ll]ang\s?[Cc]hain?\b`. (forbidden-pattern)
for response generation. It should accept a list of message dictionaries
(as described above) and return a dictionary with the following structure:
@@ -202,6 +204,7 @@ def batch_generate_messages(
predicted_messages.append(message)
return pd.DataFrame(predicted_messages)


def save_df_to_csv(df: pd.DataFrame, dir_path: str, filename: str) -> None:
"""Saves a pandas DataFrame to directory as a CSV file.
@@ -233,7 +236,9 @@ def prepare_metrics(metrics: List[str]) -> List[Any]:
*module_path, metric_name = metric.removeprefix("custom:").split(".")
metrics_evaluation.append(
__import__(".".join(module_path), fromlist=[metric_name]).__dict__[
- metric_name])
+ metric_name
+ ]
+ )
else:
metrics_evaluation.append(metric)
return metrics_evaluation
gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/chain.py

@@ -16,14 +16,13 @@
import logging
from typing import Any, Dict, Iterator

- import google
- import vertexai
- from langchain_google_community.vertex_rank import VertexAIRank
- from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
-
from app.patterns.custom_rag_qa.templates import query_rewrite_template, rag_template
from app.patterns.custom_rag_qa.vector_store import get_vector_store
from app.utils.output_types import OnChatModelStreamEvent, OnToolEndEvent, custom_chain
+ import google
+ from langchain_google_community.vertex_rank import VertexAIRank
+ from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
+ import vertexai

# Configuration
EMBEDDING_MODEL = "text-embedding-004"
@@ -52,7 +51,9 @@


@custom_chain
- def chain(input: Dict[str, Any], **kwargs: Any) -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]:
+ def chain(
+ input: Dict[str, Any], **kwargs: Any
+ ) -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]:
"""
Implements a RAG QA chain. Decorated with `custom_chain` to offer Langchain compatible astream_events

Check warning on line 58 in gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/chain.py (GitHub Actions / Check Spelling): `Langchain` matches a line_forbidden.patterns entry: `\b(?!LangChain\b)(?!langchain\b)[Ll]ang\s?[Cc]hain?\b`. (forbidden-pattern)
and invoke interface and OpenTelemetry tracing.
@@ -69,6 +70,6 @@ def chain(input: Dict[str, Any], **kwargs: Any) -> Iterator[OnToolEndEvent | OnC

# Stream LLM response
for chunk in response_chain.stream(
input={"messages": input["messages"], "relevant_documents": ranked_docs}
input={"messages": input["messages"], "relevant_documents": ranked_docs}
):
- yield OnChatModelStreamEvent(data={"chunk": chunk})
+ yield OnChatModelStreamEvent(data={"chunk": chunk})
gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/templates.py

@@ -14,14 +14,22 @@

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

- query_rewrite_template = ChatPromptTemplate.from_messages([
- ("system", "Rewrite a query to a semantic search engine using the current conversation. "
- "Provide only the rewritten query as output."),
- MessagesPlaceholder(variable_name="messages")
- ])
+ query_rewrite_template = ChatPromptTemplate.from_messages(
+ [
+ (
+ "system",
+ "Rewrite a query to a semantic search engine using the current conversation. "
+ "Provide only the rewritten query as output.",
+ ),
+ MessagesPlaceholder(variable_name="messages"),
+ ]
+ )

- rag_template = ChatPromptTemplate.from_messages([
- ("system", """You are an AI assistant for question-answering tasks. Follow these guidelines:
+ rag_template = ChatPromptTemplate.from_messages(
+ [
+ (
+ "system",
+ """You are an AI assistant for question-answering tasks. Follow these guidelines:
1. Use only the provided context to answer the question.
2. Give clear, accurate responses based on the information available.
3. If the context is insufficient, state: "I don't have enough information to answer this question."
@@ -39,6 +47,9 @@
{{ doc.page_content | safe }}
</Document {{ loop.index0 }}>
{% endfor %}
"""),
MessagesPlaceholder(variable_name="messages")
], template_format="jinja2")
""",
),
MessagesPlaceholder(variable_name="messages"),
],
template_format="jinja2",
)
gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/vector_store.py

@@ -40,8 +40,8 @@ def load_and_split_documents(url: str) -> List[Document]:


def get_vector_store(
- embedding: Embeddings, persist_path: str = PERSIST_PATH, url: str = URL
- ) -> SKLearnVectorStore:
+ embedding: Embeddings, persist_path: str = PERSIST_PATH, url: str = URL
+ ) -> SKLearnVectorStore:
"""Get or create a vector store."""
vector_store = SKLearnVectorStore(embedding=embedding, persist_path=persist_path)

gemini/sample-apps/conversational-genai-app-template/app/patterns/langgraph_dummy_agent/chain.py

@@ -30,27 +30,30 @@ def search(query: str) -> str:
return "It's 60 degrees and foggy."
return "It's 90 degrees and sunny."


tools = [search]

# 2. Set up the language model
llm = ChatVertexAI(
model="gemini-1.5-pro-001",
temperature=0,
max_tokens=1024,
streaming=True
model="gemini-1.5-pro-001", temperature=0, max_tokens=1024, streaming=True
).bind_tools(tools)


# 3. Define workflow components
def should_continue(state: MessagesState) -> str:
"""Determines whether to use tools or end the conversation."""
- last_message = state['messages'][-1]
- return "tools" if last_message.tool_calls else END # type: ignore[union-attr]
+ last_message = state["messages"][-1]
+ return "tools" if last_message.tool_calls else END # type: ignore[union-attr]

- async def call_model(state: MessagesState, config: RunnableConfig) -> Dict[str, BaseMessage]:
+
+ async def call_model(
+ state: MessagesState, config: RunnableConfig
+ ) -> Dict[str, BaseMessage]:
"""Calls the language model and returns the response."""
- response = llm.invoke(state['messages'], config)
+ response = llm.invoke(state["messages"], config)
return {"messages": response}


# 4. Create the workflow graph
workflow = StateGraph(MessagesState)
workflow.add_node("agent", call_model)
@@ -59,7 +62,7 @@ async def call_model(state: MessagesState, config: RunnableConfig) -> Dict[str,

# 5. Define graph edges
workflow.add_conditional_edges("agent", should_continue)
workflow.add_edge("tools", 'agent')
workflow.add_edge("tools", "agent")

# 6. Compile the workflow
- chain = workflow.compile()
+ chain = workflow.compile()
53 changes: 29 additions & 24 deletions gemini/sample-apps/conversational-genai-app-template/app/server.py
@@ -15,22 +15,20 @@
import json
import logging
import os
- import uuid
from typing import AsyncGenerator
+ import uuid

+ # ruff: noqa: I001
+ ## Import the chain to be used
+ from app.chain import chain
+ from app.utils.input_types import Feedback, Input, InputChat, default_serialization
+ from app.utils.output_types import EndEvent, Event
+ from app.utils.tracing import CloudTraceLoggingSpanExporter
from fastapi import FastAPI
from fastapi.responses import RedirectResponse, StreamingResponse
from google.cloud import logging as gcp_logging
from traceloop.sdk import Instruments, Traceloop

- from app.utils.input_types import Feedback, Input, InputChat, default_serialization
- from app.utils.output_types import EndEvent, Event
- from app.utils.tracing import CloudTraceLoggingSpanExporter
-
- # ruff: noqa: I001
- ## Import the chain to be used
- from app.chain import chain

# Or choose one of the following pattern chains to test by uncommenting it:

# Custom RAG QA
@@ -40,8 +38,13 @@
# from app.patterns.langgraph_dummy_agent.chain import chain

# The events that are supported by the UI Frontend
SUPPORTED_EVENTS = ["on_tool_start", "on_tool_end", "on_retriever_start",
"on_retriever_end", "on_chat_model_stream"]
SUPPORTED_EVENTS = [
"on_tool_start",
"on_tool_end",
"on_retriever_start",
"on_retriever_end",
"on_chat_model_stream",
]

# Initialize FastAPI app and logging
app = FastAPI()
@@ -63,19 +66,20 @@
async def stream_event_response(input_chat: InputChat) -> AsyncGenerator[str, None]:
run_id = uuid.uuid4()
input_dict = input_chat.model_dump()
- Traceloop.set_association_properties({
- "log_type": "tracing",
- "run_id": str(run_id),
- "user_id": input_dict["user_id"],
- "session_id": input_dict["session_id"],
- "commit_sha": os.environ.get("COMMIT_SHA", "None")
- })
+ Traceloop.set_association_properties(
+ {
+ "log_type": "tracing",
+ "run_id": str(run_id),
+ "user_id": input_dict["user_id"],
+ "session_id": input_dict["session_id"],
+ "commit_sha": os.environ.get("COMMIT_SHA", "None"),
+ }
+ )

yield json.dumps(
- Event(
- event="metadata",
- data={"run_id": str(run_id)}
- ), default=default_serialization) + "\n"
+ Event(event="metadata", data={"run_id": str(run_id)}),
+ default=default_serialization,
+ ) + "\n"

async for data in chain.astream_events(input_dict, version="v2"):
if data["event"] in SUPPORTED_EVENTS:
@@ -97,8 +101,9 @@ async def collect_feedback(feedback_dict: Feedback) -> None:

@app.post("/stream_events")
async def stream_chat_events(request: Input) -> StreamingResponse:
- return StreamingResponse(stream_event_response(input_chat=request.input),
- media_type="text/event-stream")
+ return StreamingResponse(
+ stream_event_response(input_chat=request.input), media_type="text/event-stream"
+ )


# Main execution
gemini/sample-apps/conversational-genai-app-template/app/utils/input_types.py

@@ -20,26 +20,27 @@

class InputChat(BaseModel):
"""Represents the input for a chat session."""

messages: List[Union[HumanMessage, AIMessage]] = Field(
- ...,
- description="The chat messages representing the current conversation."
+ ..., description="The chat messages representing the current conversation."
)
user_id: str = ""
session_id: str = ""


class Input(BaseModel):
"""Wrapper class for InputChat."""

input: InputChat


class Feedback(BaseModel):
"""Represents feedback for a conversation."""

score: Union[int, float]
text: Optional[str] = None
run_id: str
- log_type: Literal['feedback'] = 'feedback'

log_type: Literal["feedback"] = "feedback"


def default_serialization(obj: Any) -> Any:
(Diff for the remaining changed files not shown.)
