fix tool calling custom rag pattern
eliasecchig committed Oct 7, 2024
1 parent 0efc4e5 commit 553142d
Showing 10 changed files with 213 additions and 178 deletions.
8 changes: 3 additions & 5 deletions gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/README.md
@@ -19,15 +19,13 @@ This folder implements a chatbot application using FastAPI, and Google Cloud services

### 1. Default Chain

The default chain is a simple conversational bot that produces recipes based on user questions. It uses the Gemini 1.5 Flash model.
The default chain is a simple conversational bot that produces recipes based on user questions.

### 2. Custom RAG QA

A pythonic RAG (Retrieval-Augmented Generation) chain designed for maximum flexibility in orchestrating different components. It includes:
A RAG (Retrieval-Augmented Generation) chain using Python for orchestration and base LangChain components. The chain demonstrates how to create a production-grade application with full control over the orchestration process.

- Query rewriting
- Document retrieval and ranking
- LLM-based response generation
This approach offers maximum flexibility in orchestrating the steps and allows for seamless integration with other SDK frameworks such as the [Vertex AI SDK](https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk) and [LangChain](https://python.langchain.com/), while retaining support for emitting `astream_events` [API-compatible events](https://python.langchain.com/docs/how_to/streaming/#using-stream-events).
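
A minimal consumption sketch, assuming the chain is exposed as `app.patterns.custom_rag_qa.chain` and that each emitted event serializes to a dict with an `event` key:

```python
import asyncio

from app.patterns.custom_rag_qa.chain import chain


async def main() -> None:
    # The chain emits astream_events v2 style events, so the same consumer
    # loop works for this chain and for native LangChain runnables.
    async for event in chain.astream_events(
        {"messages": [{"type": "human", "content": "What is this app about?"}]}
    ):
        print(event["event"])


asyncio.run(main())
```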

### 3. LangGraph Dummy Agent

gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/patterns/custom_rag_qa/chain.py
@@ -12,22 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# mypy: disable-error-code="arg-type"
# mypy: disable-error-code="arg-type,attr-defined"
import logging
from typing import Any, Dict, Iterator
from typing import Any, Dict, Iterator, List

import google
import vertexai
from langchain.schema import Document
from langchain.tools import tool
from langchain_core.messages import ToolMessage
from langchain_google_community.vertex_rank import VertexAIRank
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings

from app.patterns.custom_rag_qa.templates import query_rewrite_template, rag_template
from app.patterns.custom_rag_qa.templates import (
inspect_conversation_template,
rag_template,
template_docs,
)
from app.patterns.custom_rag_qa.vector_store import get_vector_store
from app.utils.input_types import extract_human_ai_messages
from app.utils.output_types import (
OnChatModelStreamEvent,
OnToolEndEvent,
create_on_tool_end_event_from_retrieval,
custom_chain,
)

@@ -39,21 +44,59 @@
# Initialize logging
logging.basicConfig(level=logging.INFO)

# Initialize Google Cloud and Vertex AI
credentials, project_id = google.auth.default()
vertexai.init(project=project_id)

# Set up embedding model and vector store
embedding = VertexAIEmbeddings(model_name=EMBEDDING_MODEL)
vector_store = get_vector_store(embedding=embedding)
retriever = vector_store.as_retriever(search_kwargs={"k": 20})

# Initialize document compressor
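# VertexAIRank re-scores the vector-search candidates by semantic relevance
# to the query and keeps only the top_n best.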
compressor = VertexAIRank(
project_id=project_id,
location_id="global",
ranking_config="default_ranking_config",
title_field="id",
top_n=TOP_K,
)


@tool
def retrieve_docs(query: str) -> List[Document]:
"""
Useful for retrieving relevant documents based on a query.
Use this when you need additional information to answer a question.
Args:
query (str): The user's question or search query.
Returns:
List[Document]: A list of the top-ranked Document objects, limited to TOP_K (5) results.
"""
retrieved_docs = retriever.invoke(query)
ranked_docs = compressor.compress_documents(documents=retrieved_docs, query=query)
return ranked_docs


@tool
def should_continue() -> None:
"""
Use this tool if you determine that you have enough context to respond to the user's questions.
"""
return None


# Initialize language model
llm = ChatVertexAI(model=LLM_MODEL, temperature=0, max_tokens=1024)

query_gen = query_rewrite_template | llm
# Set up conversation inspector
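# tool_choice="any" forces the model to call one of the bound tools, so the
# inspection step always returns a tool call for the chain to branch on.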
inspect_conversation = inspect_conversation_template | llm.bind_tools(
[retrieve_docs, should_continue], tool_choice="any"
)

# Set up response chain
response_chain = rag_template | llm


@@ -62,24 +105,41 @@ def chain(
input: Dict[str, Any], **kwargs: Any
) -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]:
"""
Implements a RAG QA chain. Decorated with `custom_chain` to offer LangChain compatible astream_events
and invoke interface and OpenTelemetry tracing.
Implement a RAG QA chain with tool calls.
This function is decorated with `custom_chain` to offer LangChain compatible
astream_events, support for synchronous invocation through the `invoke` method,
and OpenTelemetry tracing.
"""
# Separate conversation messages from tool calls
input["messages"] = extract_human_ai_messages(input["messages"])
# Inspect conversation and determine next action
inspection_result = inspect_conversation.invoke(input)
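# The forced tool choice above guarantees at least one entry in tool_calls.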
tool_call_result = inspection_result.tool_calls[0]

# Generate optimized query
query = query_gen.invoke(input).content
# Execute the appropriate tool based on the inspection result
if tool_call_result["name"] == "retrieve_docs":
# Retrieve relevant documents
docs = retrieve_docs.invoke(tool_call_result["args"])
# Format the retrieved documents
formatted_docs = template_docs.format(docs=docs)
# Create a ToolMessage with the formatted documents
tool_message = ToolMessage(
tool_call_id=tool_call_result["id"],
name=tool_call_result["name"],
content=formatted_docs,
artifact=docs,
)
else:
# If no documents need to be retrieved, continue with the conversation
tool_message = should_continue.invoke(tool_call_result)

# Retrieve and rank documents
retrieved_docs = retriever.invoke(query)
ranked_docs = compressor.compress_documents(documents=retrieved_docs, query=query)
# Update input messages with new information
input["messages"] = input["messages"] + [inspection_result, tool_message]

# Yield tool results metadata
yield create_on_tool_end_event_from_retrieval(query=query, docs=ranked_docs)
yield OnToolEndEvent(
data={"input": tool_call_result["args"], "output": tool_message}
)

# Stream LLM response
for chunk in response_chain.stream(
input={"messages": input["messages"], "relevant_documents": ranked_docs}
):
for chunk in response_chain.stream(input=input):
yield OnChatModelStreamEvent(data={"chunk": chunk})
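
For reference, a minimal synchronous sketch; the module path is inferred from the imports above, and `invoke` is the entry point the `custom_chain` decorator advertises in the chain's docstring:

```python
from app.patterns.custom_rag_qa.chain import chain

# Synchronous entry point provided by the `custom_chain` decorator.
response = chain.invoke(
    {"messages": [{"type": "human", "content": "How do I get started?"}]}
)
```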
gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/patterns/custom_rag_qa/templates.py
@@ -1,25 +1,26 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from langchain_core.prompts import (
ChatPromptTemplate,
MessagesPlaceholder,
PromptTemplate,
)

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
template_docs = PromptTemplate.from_template(
"""## Context provided:
{% for doc in docs%}
<Document {{ loop.index0 }}>
{{ doc.page_content | safe }}
</Document {{ loop.index0 }}>
{% endfor %}
""",
template_format="jinja2",
)
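# Rendering sketch (hypothetical input): template_docs.format(docs=[Document(page_content="...")])
# yields indexed "<Document 0> ... </Document 0>" blocks under the
# "## Context provided:" header; chain.py formats retrieved docs this way
# before wrapping them in a ToolMessage.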

query_rewrite_template = ChatPromptTemplate.from_messages(
inspect_conversation_template = ChatPromptTemplate.from_messages(
[
(
"system",
"Rewrite a query to a semantic search engine using the current conversation. "
"Provide only the rewritten query as output.",
"""You are an AI assistant tasked with analyzing the conversation "
and determining the best course of action.""",
),
MessagesPlaceholder(variable_name="messages"),
]
@@ -29,27 +30,12 @@
[
(
"system",
"""You are an AI assistant for question-answering tasks. Follow these guidelines:
1. Use only the provided context to answer the question.
2. Give clear, accurate responses based on the information available.
3. If the context is insufficient, state: "I don't have enough information to answer this question."
4. Don't make assumptions or add information beyond the given context.
5. Use bullet points or lists for clarity when appropriate.
6. Briefly explain technical terms if necessary.
7. Use quotation marks for direct quotes from the context.
"""You are an AI assistant for question-answering tasks.
Answer to the best of your ability using the context provided.
If you're unsure, it's better to acknowledge limitations than to speculate.
## Context provided:
{% for doc in relevant_documents%}
<Document {{ loop.index0 }}>
{{ doc.page_content | safe }}
</Document {{ loop.index0 }}>
{% endfor %}
""",
),
MessagesPlaceholder(variable_name="messages"),
],
template_format="jinja2",
]
)
gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/utils/input_types.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from typing import Annotated, Any, List, Literal, Optional, Union

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from pydantic import BaseModel, Field
@@ -54,59 +54,3 @@ def default_serialization(obj: Any) -> Any:
"""
if isinstance(obj, BaseModel):
return obj.model_dump()


def extract_human_ai_messages(
messages: List[Union[Dict[str, Any], BaseModel]]
) -> List[Dict[str, Any]]:
"""
Extract AI and human messages with non-empty content from a list of messages.
The function will remove all messages relative to tool calls (Empty AI Messages
with tool calls and ToolMessages).
Args:
messages (List[Union[Dict[str, Any], BaseModel]]): A list of message objects.
Returns:
List[Dict[str, Any]]: A list of extracted AI and human messages with
non-empty content.
"""
extracted_messages = []
for message in messages:
if isinstance(message, BaseModel):
message = message.model_dump()

is_valid_type = message.get("type") in ["human", "ai"]
has_content = bool(message.get("content"))

if is_valid_type and has_content:
extracted_messages.append(message)

return extracted_messages


def extract_tool_calls_and_messages(
messages: List[Union[Dict[str, Any], BaseModel]]
) -> List[Dict[str, Any]]:
"""
Extract AI Messages with tool calls and ToolMessages from a list of messages.
AI Messages with tool calls define tool inputs, while ToolMessages contain outputs.
Args:
messages (List[Union[Dict[str, Any], BaseModel]]): A list of message objects.
Returns:
List[Dict[str, Any]]: A list of extracted AI messages with tool calls and ToolMessages.
"""
extracted_messages = []
for message in messages:
if isinstance(message, BaseModel):
message = message.model_dump()

is_tool_message = message.get("type") == "tool"
is_ai_with_tool_call = message.get("type") == "ai" and message.get("tool_calls")

if is_tool_message or is_ai_with_tool_call:
extracted_messages.append(message)

return extracted_messages
gemini/sample-apps/e2e-gen-ai-app-starter-kit/app/utils/output_types.py
@@ -15,9 +15,8 @@
import uuid
from functools import wraps
from types import GeneratorType
from typing import Any, AsyncGenerator, Callable, Dict, List, Literal
from typing import Any, AsyncGenerator, Callable, Dict, Literal

from langchain_core.documents import Document
from langchain_core.messages import AIMessage, AIMessageChunk, ToolMessage
from pydantic import BaseModel, Field
from traceloop.sdk import TracerWrapper
@@ -66,32 +65,6 @@ class EndEvent(BaseModel):
event: Literal["end"] = "end"


def create_on_tool_end_event_from_retrieval(
query: str,
docs: List[Document],
tool_call_id: str = "retriever",
name: str = "retriever",
) -> OnToolEndEvent:
"""
Create a LangChain astream_events v2 compatible on_tool_end event from a retrieval process.
Args:
query (str): The query used for retrieval.
docs (List[Document]): The retrieved documents.
tool_call_id (str, optional): The ID of the tool call. Defaults to "retriever".
name (str, optional): The name of the tool. Defaults to "retriever".
Returns:
OnToolEndEvent: An event representing the end of the retrieval tool execution.
"""
ranked_docs_tool_output = ToolMessage(
tool_call_id=tool_call_id, content=[doc.model_dump() for doc in docs], name=name
)
return OnToolEndEvent(
data=ToolData(input={"query": query}, output=ranked_docs_tool_output)
)


class CustomChain:
"""A custom chain class that wraps a callable function."""

@@ -111,6 +84,7 @@ async def astream_events(self, *args: Any, **kwargs: Any) -> AsyncGenerator:
func = self.func

gen: GeneratorType = func(*args, **kwargs)

for event in gen:
yield event.model_dump()
