From fe208e28d938cbbdd0fb3cac91a6b1768da19e60 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Thu, 31 Oct 2024 13:25:39 -0500 Subject: [PATCH 1/2] fix: Update `MatchingEngineIndex` to be created with `STREAM_UPDATE` --- .../multimodal_rag_langchain.ipynb | 2017 +++++++++-------- 1 file changed, 1011 insertions(+), 1006 deletions(-) diff --git a/gemini/use-cases/retrieval-augmented-generation/multimodal_rag_langchain.ipynb b/gemini/use-cases/retrieval-augmented-generation/multimodal_rag_langchain.ipynb index b01ff5f2acb..6d951d184d3 100644 --- a/gemini/use-cases/retrieval-augmented-generation/multimodal_rag_langchain.ipynb +++ b/gemini/use-cases/retrieval-augmented-generation/multimodal_rag_langchain.ipynb @@ -1,1008 +1,1013 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ijGzTHJJUCPY" - }, - "outputs": [], - "source": [ - "# Copyright 2024 Google LLC\n", - "#\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NDsTUvKjwHBW" - }, - "source": [ - "# Multimodal Retrieval Augmented Generation (RAG) with Gemini, Vertex AI Vector Search, and LangChain\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - "
\n", - " \n", - " \"Google
Run in Colab Enterprise\n", - "
\n", - "
\n", - " \n", - " \"Google
Run in Colab\n", - "
\n", - "
\n", - " \n", - " \"Vertex
Open in Vertex AI Workbench\n", - "
\n", - "
\n", - " \n", - " \"GitHub
View on GitHub\n", - "
\n", - "
" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "37454218170d" - }, - "source": [ - "| | | \n", - "|-|-|\n", - "|Author(s) | [Holt Skinner](https://github.com/holtskinner) |" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VK1Q5ZYdVL4Y" - }, - "source": [ - "## Overview\n", - "\n", - "Retrieval augmented generation (RAG) has become a popular paradigm for enabling LLMs to access external data and also as a mechanism for grounding to mitigate against hallucinations.\n", - "\n", - "In this notebook, you will learn how to perform multimodal RAG where you will perform Q&A over a financial document filled with both text and images.\n", - "\n", - "### Gemini\n", - "\n", - "Gemini is a family of generative AI models developed by Google DeepMind that is designed for multimodal use cases. The Gemini API gives you access to the Gemini 1.0 Pro Vision and Gemini 1.0 Pro models.\n", - "\n", - "### Comparing text-based and multimodal RAG\n", - "\n", - "Multimodal RAG offers several advantages over text-based RAG:\n", - "\n", - "1. **Enhanced knowledge access:** Multimodal RAG can access and process both textual and visual information, providing a richer and more comprehensive knowledge base for the LLM.\n", - "2. **Improved reasoning capabilities:** By incorporating visual cues, multimodal RAG can make better informed inferences across different types of data modalities.\n", - "\n", - "This notebook shows you how to use RAG with Vertex AI Gemini API, and [multimodal embeddings](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/multimodal-embeddings), to build a document search engine.\n", - "\n", - "Through hands-on examples, you will discover how to construct a multimedia-rich metadata repository of your document sources, enabling search, comparison, and reasoning across diverse information streams." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RQT500QqVPIb" - }, - "source": [ - "### Objectives\n", - "\n", - "This notebook provides a guide to building a document search engine using multimodal retrieval augmented generation (RAG), step by step:\n", - "\n", - "1. Extract and store metadata of documents containing both text and images, and generate embeddings the documents\n", - "2. Search the metadata with text queries to find similar text or images\n", - "3. Search the metadata with image queries to find similar images\n", - "4. Using a text query as input, search for contextual answers using both text and images" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KnpYxfesh2rI" - }, - "source": [ - "### Costs\n", - "\n", - "This tutorial uses billable components of Google Cloud:\n", - "\n", - "- Vertex AI\n", - "\n", - "Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DXJpXzKrh2rJ" - }, - "source": [ - "## Getting Started" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "N5afkyDMSBW5" - }, - "source": [ - "### Install Vertex AI SDK for Python and other dependencies" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kc4WxYmLSBW5" - }, - "outputs": [], - "source": [ - "%pip install -U -q google-cloud-aiplatform langchain-core langchain-google-vertexai langchain-text-splitters langchain-community \"unstructured[all-docs]\" pypdf pydantic lxml pillow matplotlib opencv-python tiktoken" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "R5Xep4W9lq-Z" - }, - "source": [ - "### Restart current runtime\n", - "\n", - "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which will restart the current kernel." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XRvKdaPDTznN" - }, - "outputs": [], - "source": [ - "# Restart kernel after installs so that your environment can access the new packages\n", - "import IPython\n", - "\n", - "app = IPython.Application.instance()\n", - "app.kernel.do_shutdown(True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SbmM4z7FOBpM" - }, - "source": [ - "
\n", - "⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️\n", - "
\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FtsU9Bw9h2rL" - }, - "source": [ - "### Authenticate your notebook environment (Colab only)\n", - "\n", - "If you are running this notebook on Google Colab, run the following cell to authenticate your environment. This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "GpYEyLsOh2rL" - }, - "outputs": [], - "source": [ - "import sys\n", - "\n", - "# Additional authentication is required for Google Colab\n", - "if \"google.colab\" in sys.modules:\n", - " # Authenticate user to Google Cloud\n", - " from google.colab import auth\n", - "\n", - " auth.authenticate_user()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O1vKZZoEh2rL" - }, - "source": [ - "### Define Google Cloud project information" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gJqZ76rJh2rM" - }, - "outputs": [], - "source": [ - "PROJECT_ID = \"YOUR_PROJECT_ID\" # @param {type:\"string\"}\n", - "LOCATION = \"us-central1\" # @param {type:\"string\"}\n", - "\n", - "# For Vector Search Staging\n", - "GCS_BUCKET = \"YOUR_BUCKET_NAME\" # @param {type:\"string\"}\n", - "GCS_BUCKET_URI = f\"gs://{GCS_BUCKET}\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "57262621bd1c" - }, - "source": [ - "### Initialize the Vertex AI SDK" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "D48gUW5-h2rM" - }, - "outputs": [], - "source": [ - "from google.cloud import aiplatform\n", - "\n", - "aiplatform.init(project=PROJECT_ID, location=LOCATION, staging_bucket=GCS_BUCKET_URI)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BuQwwRiniVFG" - }, - "source": [ - "### Import libraries" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rtMowvm-yQ97" - }, - "outputs": [], - "source": [ - "import base64\n", - "import os\n", - "import re\n", - "import uuid\n", - "\n", - "from IPython.display import Image, Markdown, display\n", - "from langchain.prompts import PromptTemplate\n", - "from langchain.retrievers.multi_vector import MultiVectorRetriever\n", - "from langchain.storage import InMemoryStore\n", - "from langchain_core.documents import Document\n", - "from langchain_core.messages import AIMessage, HumanMessage\n", - "from langchain_core.output_parsers import StrOutputParser\n", - "from langchain_core.runnables import RunnableLambda, RunnablePassthrough\n", - "from langchain_google_vertexai import (\n", - " ChatVertexAI,\n", - " VectorSearchVectorStore,\n", - " VertexAI,\n", - " VertexAIEmbeddings,\n", - ")\n", - "from langchain_text_splitters import CharacterTextSplitter\n", - "from unstructured.partition.pdf import partition_pdf\n", - "\n", - "# from langchain_community.vectorstores import Chroma # Optional" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2bf3ee5d1686" - }, - "source": [ - "### Define model information\n", - "\n", - "- [Vertex AI - Model Information](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "eb39bdada39d" - }, - "outputs": [], - "source": [ - "MODEL_NAME = \"gemini-1.5-flash\"\n", - "GEMINI_OUTPUT_TOKEN_LIMIT = 8192\n", - "\n", - "EMBEDDING_MODEL_NAME = \"text-embedding-004\"\n", - "EMBEDDING_TOKEN_LIMIT = 2048\n", - "\n", - "TOKEN_LIMIT = min(GEMINI_OUTPUT_TOKEN_LIMIT, EMBEDDING_TOKEN_LIMIT)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2c919bd5a462" - }, - "source": [ - "## Data Loading" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "g7bKCQMFT7JT" - }, - "source": [ - "#### Get documents and images from GCS" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KwbL89zcY39N" - }, - "outputs": [], - "source": [ - "# Download documents and images used in this notebook\n", - "!gsutil -m rsync -r gs://github-repo/rag/intro_multimodal_rag/ .\n", - "print(\"Download completed\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ps1G-cCfpibN" - }, - "source": [ - "## Partition PDF tables, text, and images" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jqLsy3iZ5t-R" - }, - "source": [ - "### The data\n", - "\n", - "The source data that you will use in this notebook is a modified version of [Google-10K](https://abc.xyz/assets/investor/static/pdf/20220202_alphabet_10K.pdf) which provides a comprehensive overview of the company's financial performance, business operations, management, and risk factors. As the original document is rather large, you will be using [a modified version with only 14 pages](https://storage.googleapis.com/github-repo/rag/multimodal_rag_langchain/google-10k-sample-14pages.pdf) instead. Although it's truncated, the sample document still contains text along with images such as tables, charts, and graphs." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3a87cb1a097b" - }, - "outputs": [], - "source": [ - "pdf_folder_path = \"/content/data/\" if \"google.colab\" in sys.modules else \"data/\"\n", - "pdf_file_name = \"google-10k-sample-14pages.pdf\"\n", - "\n", - "# Extract images, tables, and chunk text from a PDF file.\n", - "raw_pdf_elements = partition_pdf(\n", - " filename=pdf_file_name,\n", - " extract_images_in_pdf=False,\n", - " infer_table_structure=True,\n", - " chunking_strategy=\"by_title\",\n", - " max_characters=4000,\n", - " new_after_n_chars=3800,\n", - " combine_text_under_n_chars=2000,\n", - " image_output_dir_path=pdf_folder_path,\n", - ")\n", - "\n", - "# Categorize extracted elements from a PDF into tables and texts.\n", - "tables = []\n", - "texts = []\n", - "for element in raw_pdf_elements:\n", - " if \"unstructured.documents.elements.Table\" in str(type(element)):\n", - " tables.append(str(element))\n", - " elif \"unstructured.documents.elements.CompositeElement\" in str(type(element)):\n", - " texts.append(str(element))\n", - "\n", - "# Optional: Enforce a specific token size for texts\n", - "text_splitter = CharacterTextSplitter.from_tiktoken_encoder(\n", - " chunk_size=10000, chunk_overlap=0\n", - ")\n", - "joined_texts = \" \".join(texts)\n", - "texts_4k_token = text_splitter.split_text(joined_texts)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "244963a30836" - }, - "outputs": [], - "source": [ - "# Generate summaries of text elements\n", - "\n", - "\n", - "def generate_text_summaries(\n", - " texts: list[str], tables: list[str], summarize_texts: bool = False\n", - ") -> tuple[list, list]:\n", - " \"\"\"\n", - " Summarize text elements\n", - " texts: List of str\n", - " tables: List of str\n", - " summarize_texts: Bool to summarize texts\n", - " \"\"\"\n", - "\n", - " # Prompt\n", - " prompt_text = \"\"\"You are an assistant tasked with summarizing tables and text for retrieval. \\\n", - " These summaries will be embedded and used to retrieve the raw text or table elements. \\\n", - " Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element} \"\"\"\n", - " prompt = PromptTemplate.from_template(prompt_text)\n", - " empty_response = RunnableLambda(\n", - " lambda x: AIMessage(content=\"Error processing document\")\n", - " )\n", - " # Text summary chain\n", - " model = VertexAI(\n", - " temperature=0, model_name=MODEL_NAME, max_output_tokens=TOKEN_LIMIT\n", - " ).with_fallbacks([empty_response])\n", - " summarize_chain = {\"element\": lambda x: x} | prompt | model | StrOutputParser()\n", - "\n", - " # Initialize empty summaries\n", - " text_summaries = []\n", - " table_summaries = []\n", - "\n", - " # Apply to text if texts are provided and summarization is requested\n", - " if texts:\n", - " if summarize_texts:\n", - " text_summaries = summarize_chain.batch(texts, {\"max_concurrency\": 1})\n", - " else:\n", - " text_summaries = texts\n", - "\n", - " # Apply to tables if tables are provided\n", - " if tables:\n", - " table_summaries = summarize_chain.batch(tables, {\"max_concurrency\": 1})\n", - "\n", - " return text_summaries, table_summaries\n", - "\n", - "\n", - "# Get text, table summaries\n", - "text_summaries, table_summaries = generate_text_summaries(\n", - " texts_4k_token, tables, summarize_texts=True\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "379ae4ffbf83" - }, - "outputs": [], - "source": [ - "def encode_image(image_path: str) -> str:\n", - " \"\"\"Getting the base64 string\"\"\"\n", - " with open(image_path, \"rb\") as image_file:\n", - " return base64.b64encode(image_file.read()).decode(\"utf-8\")\n", - "\n", - "\n", - "def image_summarize(model: ChatVertexAI, base64_image: str, prompt: str) -> str:\n", - " \"\"\"Make image summary\"\"\"\n", - " msg = model.invoke(\n", - " [\n", - " HumanMessage(\n", - " content=[\n", - " {\"type\": \"text\", \"text\": prompt},\n", - " {\n", - " \"type\": \"image_url\",\n", - " \"image_url\": {\"url\": f\"data:image/png;base64,{base64_image}\"},\n", - " },\n", - " ]\n", - " )\n", - " ]\n", - " )\n", - " return msg.content\n", - "\n", - "\n", - "def generate_img_summaries(path: str) -> tuple[list[str], list[str]]:\n", - " \"\"\"\n", - " Generate summaries and base64 encoded strings for images\n", - " path: Path to list of .jpg files extracted by Unstructured\n", - " \"\"\"\n", - "\n", - " # Store base64 encoded images\n", - " img_base64_list = []\n", - "\n", - " # Store image summaries\n", - " image_summaries = []\n", - "\n", - " # Prompt\n", - " prompt = \"\"\"You are an assistant tasked with summarizing images for retrieval. \\\n", - " These summaries will be embedded and used to retrieve the raw image. \\\n", - " Give a concise summary of the image that is well optimized for retrieval.\n", - " If it's a table, extract all elements of the table.\n", - " If it's a graph, explain the findings in the graph.\n", - " Do not include any numbers that are not mentioned in the image.\n", - " \"\"\"\n", - "\n", - " model = ChatVertexAI(model_name=MODEL_NAME, max_output_tokens=TOKEN_LIMIT)\n", - "\n", - " # Apply to images\n", - " for img_file in sorted(os.listdir(path)):\n", - " if img_file.endswith(\".png\"):\n", - " base64_image = encode_image(os.path.join(path, img_file))\n", - " img_base64_list.append(base64_image)\n", - " image_summaries.append(image_summarize(model, base64_image, prompt))\n", - "\n", - " return img_base64_list, image_summaries\n", - "\n", - "\n", - "# Image summaries\n", - "img_base64_list, image_summaries = generate_img_summaries(\".\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b641a76265d0" - }, - "source": [ - "## Create & Deploy Vertex AI Vector Search Index & Endpoint\n", - "\n", - "Skip this step if you already have Vector Search set up.\n", - "\n", - "- https://console.cloud.google.com/vertex-ai/matching-engine/indexes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "c15693534ed1" - }, - "source": [ - "- Create [`MatchingEngineIndex`](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.MatchingEngineIndex)\n", - " - https://cloud.google.com/vertex-ai/docs/vector-search/create-manage-index" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "dad379accb68" - }, - "outputs": [], - "source": [ - "# https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings\n", - "DIMENSIONS = 768 # Dimensions output from textembedding-gecko\n", - "\n", - "index = aiplatform.MatchingEngineIndex.create_tree_ah_index(\n", - " display_name=\"mm_rag_langchain_index\",\n", - " dimensions=DIMENSIONS,\n", - " approximate_neighbors_count=150,\n", - " leaf_node_embedding_count=500,\n", - " leaf_nodes_to_search_percent=7,\n", - " description=\"Multimodal RAG LangChain Index\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "add71035aaa1" - }, - "source": [ - "- Create [`MatchingEngineIndexEndpoint`](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.MatchingEngineIndexEndpoint)\n", - " - https://cloud.google.com/vertex-ai/docs/vector-search/deploy-index-public" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "140c0142b90f" - }, - "outputs": [], - "source": [ - "DEPLOYED_INDEX_ID = \"mm_rag_langchain_index_endpoint\"\n", - "\n", - "index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(\n", - " display_name=DEPLOYED_INDEX_ID,\n", - " description=\"Multimodal RAG LangChain Index Endpoint\",\n", - " public_endpoint_enabled=True,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b6adda75cab6" - }, - "source": [ - "- Deploy Index to Index Endpoint\n", - " - NOTE: This will take a while to run.\n", - " - You can stop this cell after starting it instead of waiting for deployment.\n", - " - You can check the status at https://console.cloud.google.com/vertex-ai/matching-engine/indexes" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4a02468a018b" - }, - "outputs": [], - "source": [ - "index_endpoint = index_endpoint.deploy_index(\n", - " index=index, deployed_index_id=\"mm_rag_langchain_deployed_index\"\n", - ")\n", - "index_endpoint.deployed_indexes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bd8475f61ef9" - }, - "source": [ - "## Create retriever & load documents" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "487ad4e4ccac" - }, - "source": [ - "- Create [`VectorSearchVectorStore`](https://api.python.langchain.com/en/latest/vectorstores/langchain_google_vertexai.vectorstores.vectorstores.VectorSearchVectorStore.html) with Vector Search Index ID and Endpoint ID.\n", - "- Use [`textembedding-gecko`](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings) as embedding model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "e49355d04889" - }, - "outputs": [], - "source": [ - "# The vectorstore to use to index the summaries\n", - "vectorstore = VectorSearchVectorStore.from_components(\n", - " project_id=PROJECT_ID,\n", - " region=LOCATION,\n", - " gcs_bucket_name=GCS_BUCKET,\n", - " index_id=index.name,\n", - " endpoint_id=index_endpoint.name,\n", - " embedding=VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME),\n", - " stream_update=True,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "26ef209ff8ba" - }, - "source": [ - "- Alternatively, use Chroma for a local vector store." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "1b7f713b2607" - }, - "outputs": [], - "source": [ - "# vectorstore = Chroma(\n", - "# collection_name=\"mm_rag_test\",\n", - "# embedding_function=VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME),\n", - "# )" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "67a4b0490b45" - }, - "source": [ - "- Create Multi-Vector Retriever using the vector store you created.\n", - "- Since vector stores only contain the embedding and an ID, you'll also need to create a document store indexed by ID to get the original source documents after searching for embeddings." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8e92ff890483" - }, - "outputs": [], - "source": [ - "docstore = InMemoryStore()\n", - "\n", - "id_key = \"doc_id\"\n", - "# Create the multi-vector retriever\n", - "retriever_multi_vector_img = MultiVectorRetriever(\n", - " vectorstore=vectorstore,\n", - " docstore=docstore,\n", - " id_key=id_key,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "96b37cf7dc47" - }, - "source": [ - "- Load data into Document Store and Vector Store" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0a92a4b04319" - }, - "outputs": [], - "source": [ - "# Raw Document Contents\n", - "doc_contents = texts + tables + img_base64_list\n", - "\n", - "doc_ids = [str(uuid.uuid4()) for _ in doc_contents]\n", - "summary_docs = [\n", - " Document(page_content=s, metadata={id_key: doc_ids[i]})\n", - " for i, s in enumerate(text_summaries + table_summaries + image_summaries)\n", - "]\n", - "\n", - "retriever_multi_vector_img.docstore.mset(list(zip(doc_ids, doc_contents)))\n", - "\n", - "# If using Vertex AI Vector Search, this will take a while to complete.\n", - "# You can cancel this cell and continue later.\n", - "retriever_multi_vector_img.vectorstore.add_documents(summary_docs)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b665ead18f3b" - }, - "source": [ - "## Create Chain with Retriever and Gemini LLM" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "5228d5831d34" - }, - "outputs": [], - "source": [ - "def looks_like_base64(sb):\n", - " \"\"\"Check if the string looks like base64\"\"\"\n", - " return re.match(\"^[A-Za-z0-9+/]+[=]{0,2}$\", sb) is not None\n", - "\n", - "\n", - "def is_image_data(b64data):\n", - " \"\"\"\n", - " Check if the base64 data is an image by looking at the start of the data\n", - " \"\"\"\n", - " image_signatures = {\n", - " b\"\\xFF\\xD8\\xFF\": \"jpg\",\n", - " b\"\\x89\\x50\\x4E\\x47\\x0D\\x0A\\x1A\\x0A\": \"png\",\n", - " b\"\\x47\\x49\\x46\\x38\": \"gif\",\n", - " b\"\\x52\\x49\\x46\\x46\": \"webp\",\n", - " }\n", - " try:\n", - " header = base64.b64decode(b64data)[:8] # Decode and get the first 8 bytes\n", - " for sig, format in image_signatures.items():\n", - " if header.startswith(sig):\n", - " return True\n", - " return False\n", - " except Exception:\n", - " return False\n", - "\n", - "\n", - "def split_image_text_types(docs):\n", - " \"\"\"\n", - " Split base64-encoded images and texts\n", - " \"\"\"\n", - " b64_images = []\n", - " texts = []\n", - " for doc in docs:\n", - " # Check if the document is of type Document and extract page_content if so\n", - " if isinstance(doc, Document):\n", - " doc = doc.page_content\n", - " if looks_like_base64(doc) and is_image_data(doc):\n", - " b64_images.append(doc)\n", - " else:\n", - " texts.append(doc)\n", - " return {\"images\": b64_images, \"texts\": texts}\n", - "\n", - "\n", - "def img_prompt_func(data_dict):\n", - " \"\"\"\n", - " Join the context into a single string\n", - " \"\"\"\n", - " formatted_texts = \"\\n\".join(data_dict[\"context\"][\"texts\"])\n", - " messages = [\n", - " {\n", - " \"type\": \"text\",\n", - " \"text\": (\n", - " \"You are financial analyst tasking with providing investment advice.\\n\"\n", - " \"You will be given a mix of text, tables, and image(s) usually of charts or graphs.\\n\"\n", - " \"Use this information to provide investment advice related to the user's question. \\n\"\n", - " f\"User-provided question: {data_dict['question']}\\n\\n\"\n", - " \"Text and / or tables:\\n\"\n", - " f\"{formatted_texts}\"\n", - " ),\n", - " }\n", - " ]\n", - "\n", - " # Adding image(s) to the messages if present\n", - " if data_dict[\"context\"][\"images\"]:\n", - " for image in data_dict[\"context\"][\"images\"]:\n", - " messages.append(\n", - " {\n", - " \"type\": \"image_url\",\n", - " \"image_url\": {\"url\": f\"data:image/jpeg;base64,{image}\"},\n", - " }\n", - " )\n", - " return [HumanMessage(content=messages)]\n", - "\n", - "\n", - "# Create RAG chain\n", - "chain_multimodal_rag = (\n", - " {\n", - " \"context\": retriever_multi_vector_img | RunnableLambda(split_image_text_types),\n", - " \"question\": RunnablePassthrough(),\n", - " }\n", - " | RunnableLambda(img_prompt_func)\n", - " | ChatVertexAI(\n", - " temperature=0,\n", - " model_name=MODEL_NAME,\n", - " max_output_tokens=TOKEN_LIMIT,\n", - " ) # Multi-modal LLM\n", - " | StrOutputParser()\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2925d397fdbb" - }, - "source": [ - "## Process user query" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "b3b445f934a8" - }, - "outputs": [], - "source": [ - "query = \"\"\"\n", - " - What are the critical difference between various graphs for Class A Share?\n", - " - Which index best matches Class A share performance closely where Google is not already a part? Explain the reasoning.\n", - " - Identify key chart patterns for Google Class A shares.\n", - " - What is cost of revenues, operating expenses and net income for 2020. Do mention the percentage change\n", - " - What was the effect of Covid in the 2020 financial year?\n", - " - What are the total revenues for APAC and USA for 2021?\n", - " - What is deferred income taxes?\n", - " - How do you compute net income per share?\n", - " - What drove percentage change in the consolidated revenue and cost of revenue for the year 2021 and was there any effect of Covid?\n", - " - What is the cause of 41% increase in revenue from 2020 to 2021 and how much is dollar change?\n", - "\"\"\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6172a22b1203" - }, - "source": [ - "### Get Retrieved documents" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "90a34d3712e0" - }, - "outputs": [], - "source": [ - "# List of source documents\n", - "docs = retriever_multi_vector_img.get_relevant_documents(query, limit=10)\n", - "\n", - "source_docs = split_image_text_types(docs)\n", - "\n", - "print(source_docs[\"texts\"])\n", - "\n", - "for i in source_docs[\"images\"]:\n", - " display(Image(base64.b64decode(i)))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bd784ce7f205" - }, - "source": [ - "### Get generative response" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4c5f936f89da" - }, - "outputs": [], - "source": [ - "result = chain_multimodal_rag.invoke(query)\n", - "\n", - "Markdown(result)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KwNrHCqbi3xi" - }, - "source": [ - "## Conclusions" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "05jynhZnkgxn" - }, - "source": [ - "Congratulations on making it through this multimodal RAG notebook!\n", - "\n", - "While multimodal RAG can be quite powerful, note that it can face some limitations:\n", - "\n", - "* **Data dependency:** Needs high-accuracy data from the text and visuals.\n", - "* **Computationally demanding:** Generating embeddings from multimodal data is resource-intensive.\n", - "* **Domain specific:** Models trained on general data may not shine in specialized fields like medicine.\n", - "* **Black box:** Understanding how these models work can be tricky, hindering trust and adoption.\n", - "\n", - "\n", - "Despite these challenges, multimodal RAG represents a significant step towards search and retrieval systems that can handle diverse, multimodal data." - ] - } - ], - "metadata": { - "colab": { - "name": "multimodal_rag_langchain.ipynb", - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ijGzTHJJUCPY" + }, + "outputs": [], + "source": [ + "# Copyright 2024 Google LLC\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NDsTUvKjwHBW" + }, + "source": [ + "# Multimodal Retrieval Augmented Generation (RAG) with Gemini, Vertex AI Vector Search, and LangChain\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \"Google
Run in Colab Enterprise\n", + "
\n", + "
\n", + " \n", + " \"Google
Run in Colab\n", + "
\n", + "
\n", + " \n", + " \"Vertex
Open in Vertex AI Workbench\n", + "
\n", + "
\n", + " \n", + " \"GitHub
View on GitHub\n", + "
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "37454218170d" + }, + "source": [ + "| | | \n", + "|-|-|\n", + "|Author(s) | [Holt Skinner](https://github.com/holtskinner) |" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VK1Q5ZYdVL4Y" + }, + "source": [ + "## Overview\n", + "\n", + "Retrieval augmented generation (RAG) has become a popular paradigm for enabling LLMs to access external data and also as a mechanism for grounding to mitigate against hallucinations.\n", + "\n", + "In this notebook, you will learn how to perform multimodal RAG where you will perform Q&A over a financial document filled with both text and images.\n", + "\n", + "### Gemini\n", + "\n", + "Gemini is a family of generative AI models developed by Google DeepMind that is designed for multimodal use cases. The Gemini API gives you access to the Gemini 1.0 Pro Vision and Gemini 1.0 Pro models.\n", + "\n", + "### Comparing text-based and multimodal RAG\n", + "\n", + "Multimodal RAG offers several advantages over text-based RAG:\n", + "\n", + "1. **Enhanced knowledge access:** Multimodal RAG can access and process both textual and visual information, providing a richer and more comprehensive knowledge base for the LLM.\n", + "2. **Improved reasoning capabilities:** By incorporating visual cues, multimodal RAG can make better informed inferences across different types of data modalities.\n", + "\n", + "This notebook shows you how to use RAG with Vertex AI Gemini API, and [multimodal embeddings](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/multimodal-embeddings), to build a document search engine.\n", + "\n", + "Through hands-on examples, you will discover how to construct a multimedia-rich metadata repository of your document sources, enabling search, comparison, and reasoning across diverse information streams." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RQT500QqVPIb" + }, + "source": [ + "### Objectives\n", + "\n", + "This notebook provides a guide to building a document search engine using multimodal retrieval augmented generation (RAG), step by step:\n", + "\n", + "1. Extract and store metadata of documents containing both text and images, and generate embeddings the documents\n", + "2. Search the metadata with text queries to find similar text or images\n", + "3. Search the metadata with image queries to find similar images\n", + "4. Using a text query as input, search for contextual answers using both text and images" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KnpYxfesh2rI" + }, + "source": [ + "### Costs\n", + "\n", + "This tutorial uses billable components of Google Cloud:\n", + "\n", + "- Vertex AI\n", + "\n", + "Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DXJpXzKrh2rJ" + }, + "source": [ + "## Getting Started" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N5afkyDMSBW5" + }, + "source": [ + "### Install Vertex AI SDK for Python and other dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kc4WxYmLSBW5" + }, + "outputs": [], + "source": [ + "%pip install -U -q google-cloud-aiplatform langchain-core langchain-google-vertexai langchain-text-splitters langchain-community \"unstructured[all-docs]\" pypdf pydantic lxml pillow matplotlib opencv-python tiktoken" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R5Xep4W9lq-Z" + }, + "source": [ + "### Restart current runtime\n", + "\n", + "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which will restart the current kernel." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XRvKdaPDTznN" + }, + "outputs": [], + "source": [ + "# Restart kernel after installs so that your environment can access the new packages\n", + "import IPython\n", + "\n", + "app = IPython.Application.instance()\n", + "app.kernel.do_shutdown(True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SbmM4z7FOBpM" + }, + "source": [ + "
\n", + "⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FtsU9Bw9h2rL" + }, + "source": [ + "### Authenticate your notebook environment (Colab only)\n", + "\n", + "If you are running this notebook on Google Colab, run the following cell to authenticate your environment. This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GpYEyLsOh2rL" + }, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "# Additional authentication is required for Google Colab\n", + "if \"google.colab\" in sys.modules:\n", + " # Authenticate user to Google Cloud\n", + " from google.colab import auth\n", + "\n", + " auth.authenticate_user()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O1vKZZoEh2rL" + }, + "source": [ + "### Define Google Cloud project information" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gJqZ76rJh2rM" + }, + "outputs": [], + "source": [ + "PROJECT_ID = \"YOUR_PROJECT_ID\" # @param {type:\"string\"}\n", + "LOCATION = \"us-central1\" # @param {type:\"string\"}\n", + "\n", + "# For Vector Search Staging\n", + "GCS_BUCKET = \"YOUR_BUCKET_NAME\" # @param {type:\"string\"}\n", + "GCS_BUCKET_URI = f\"gs://{GCS_BUCKET}\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "57262621bd1c" + }, + "source": [ + "### Initialize the Vertex AI SDK" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "D48gUW5-h2rM" + }, + "outputs": [], + "source": [ + "from google.cloud import aiplatform\n", + "\n", + "aiplatform.init(project=PROJECT_ID, location=LOCATION, staging_bucket=GCS_BUCKET_URI)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BuQwwRiniVFG" + }, + "source": [ + "### Import libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rtMowvm-yQ97" + }, + "outputs": [], + "source": [ + "import base64\n", + "import os\n", + "import re\n", + "import uuid\n", + "\n", + "from IPython.display import Image, Markdown, display\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain.retrievers.multi_vector import MultiVectorRetriever\n", + "from langchain.storage import InMemoryStore\n", + "from langchain_core.documents import Document\n", + "from langchain_core.messages import AIMessage, HumanMessage\n", + "from langchain_core.output_parsers import StrOutputParser\n", + "from langchain_core.runnables import RunnableLambda, RunnablePassthrough\n", + "from langchain_google_vertexai import (\n", + " ChatVertexAI,\n", + " VectorSearchVectorStore,\n", + " VertexAI,\n", + " VertexAIEmbeddings,\n", + ")\n", + "from langchain_text_splitters import CharacterTextSplitter\n", + "from unstructured.partition.pdf import partition_pdf\n", + "\n", + "# from langchain_community.vectorstores import Chroma # Optional" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2bf3ee5d1686" + }, + "source": [ + "### Define model information\n", + "\n", + "- [Vertex AI - Model Information](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "eb39bdada39d" + }, + "outputs": [], + "source": [ + "MODEL_NAME = \"gemini-1.5-flash\"\n", + "GEMINI_OUTPUT_TOKEN_LIMIT = 8192\n", + "\n", + "EMBEDDING_MODEL_NAME = \"text-embedding-004\"\n", + "EMBEDDING_TOKEN_LIMIT = 2048\n", + "\n", + "TOKEN_LIMIT = min(GEMINI_OUTPUT_TOKEN_LIMIT, EMBEDDING_TOKEN_LIMIT)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2c919bd5a462" + }, + "source": [ + "## Data Loading" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "g7bKCQMFT7JT" + }, + "source": [ + "#### Get documents and images from GCS" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KwbL89zcY39N" + }, + "outputs": [], + "source": [ + "# Download documents and images used in this notebook\n", + "!gsutil -m rsync -r gs://github-repo/rag/intro_multimodal_rag/ .\n", + "print(\"Download completed\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ps1G-cCfpibN" + }, + "source": [ + "## Partition PDF tables, text, and images" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jqLsy3iZ5t-R" + }, + "source": [ + "### The data\n", + "\n", + "The source data that you will use in this notebook is a modified version of [Google-10K](https://abc.xyz/assets/investor/static/pdf/20220202_alphabet_10K.pdf) which provides a comprehensive overview of the company's financial performance, business operations, management, and risk factors. As the original document is rather large, you will be using [a modified version with only 14 pages](https://storage.googleapis.com/github-repo/rag/multimodal_rag_langchain/google-10k-sample-14pages.pdf) instead. Although it's truncated, the sample document still contains text along with images such as tables, charts, and graphs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3a87cb1a097b" + }, + "outputs": [], + "source": [ + "pdf_folder_path = \"/content/data/\" if \"google.colab\" in sys.modules else \"data/\"\n", + "pdf_file_name = \"google-10k-sample-14pages.pdf\"\n", + "\n", + "# Extract images, tables, and chunk text from a PDF file.\n", + "raw_pdf_elements = partition_pdf(\n", + " filename=pdf_file_name,\n", + " extract_images_in_pdf=False,\n", + " infer_table_structure=True,\n", + " chunking_strategy=\"by_title\",\n", + " max_characters=4000,\n", + " new_after_n_chars=3800,\n", + " combine_text_under_n_chars=2000,\n", + " image_output_dir_path=pdf_folder_path,\n", + ")\n", + "\n", + "# Categorize extracted elements from a PDF into tables and texts.\n", + "tables = []\n", + "texts = []\n", + "for element in raw_pdf_elements:\n", + " if \"unstructured.documents.elements.Table\" in str(type(element)):\n", + " tables.append(str(element))\n", + " elif \"unstructured.documents.elements.CompositeElement\" in str(type(element)):\n", + " texts.append(str(element))\n", + "\n", + "# Optional: Enforce a specific token size for texts\n", + "text_splitter = CharacterTextSplitter.from_tiktoken_encoder(\n", + " chunk_size=10000, chunk_overlap=0\n", + ")\n", + "joined_texts = \" \".join(texts)\n", + "texts_4k_token = text_splitter.split_text(joined_texts)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "244963a30836" + }, + "outputs": [], + "source": [ + "# Generate summaries of text elements\n", + "\n", + "\n", + "def generate_text_summaries(\n", + " texts: list[str], tables: list[str], summarize_texts: bool = False\n", + ") -> tuple[list, list]:\n", + " \"\"\"\n", + " Summarize text elements\n", + " texts: List of str\n", + " tables: List of str\n", + " summarize_texts: Bool to summarize texts\n", + " \"\"\"\n", + "\n", + " # Prompt\n", + " prompt_text = \"\"\"You are an assistant tasked with summarizing tables and text for retrieval. \\\n", + " These summaries will be embedded and used to retrieve the raw text or table elements. \\\n", + " Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element} \"\"\"\n", + " prompt = PromptTemplate.from_template(prompt_text)\n", + " empty_response = RunnableLambda(\n", + " lambda x: AIMessage(content=\"Error processing document\")\n", + " )\n", + " # Text summary chain\n", + " model = VertexAI(\n", + " temperature=0, model_name=MODEL_NAME, max_output_tokens=TOKEN_LIMIT\n", + " ).with_fallbacks([empty_response])\n", + " summarize_chain = {\"element\": lambda x: x} | prompt | model | StrOutputParser()\n", + "\n", + " # Initialize empty summaries\n", + " text_summaries = []\n", + " table_summaries = []\n", + "\n", + " # Apply to text if texts are provided and summarization is requested\n", + " if texts:\n", + " if summarize_texts:\n", + " text_summaries = summarize_chain.batch(texts, {\"max_concurrency\": 1})\n", + " else:\n", + " text_summaries = texts\n", + "\n", + " # Apply to tables if tables are provided\n", + " if tables:\n", + " table_summaries = summarize_chain.batch(tables, {\"max_concurrency\": 1})\n", + "\n", + " return text_summaries, table_summaries\n", + "\n", + "\n", + "# Get text, table summaries\n", + "text_summaries, table_summaries = generate_text_summaries(\n", + " texts_4k_token, tables, summarize_texts=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "379ae4ffbf83" + }, + "outputs": [], + "source": [ + "def encode_image(image_path: str) -> str:\n", + " \"\"\"Getting the base64 string\"\"\"\n", + " with open(image_path, \"rb\") as image_file:\n", + " return base64.b64encode(image_file.read()).decode(\"utf-8\")\n", + "\n", + "\n", + "def image_summarize(model: ChatVertexAI, base64_image: str, prompt: str) -> str:\n", + " \"\"\"Make image summary\"\"\"\n", + " msg = model.invoke(\n", + " [\n", + " HumanMessage(\n", + " content=[\n", + " {\"type\": \"text\", \"text\": prompt},\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\"url\": f\"data:image/png;base64,{base64_image}\"},\n", + " },\n", + " ]\n", + " )\n", + " ]\n", + " )\n", + " return msg.content\n", + "\n", + "\n", + "def generate_img_summaries(path: str) -> tuple[list[str], list[str]]:\n", + " \"\"\"\n", + " Generate summaries and base64 encoded strings for images\n", + " path: Path to list of .jpg files extracted by Unstructured\n", + " \"\"\"\n", + "\n", + " # Store base64 encoded images\n", + " img_base64_list = []\n", + "\n", + " # Store image summaries\n", + " image_summaries = []\n", + "\n", + " # Prompt\n", + " prompt = \"\"\"You are an assistant tasked with summarizing images for retrieval. \\\n", + " These summaries will be embedded and used to retrieve the raw image. \\\n", + " Give a concise summary of the image that is well optimized for retrieval.\n", + " If it's a table, extract all elements of the table.\n", + " If it's a graph, explain the findings in the graph.\n", + " Do not include any numbers that are not mentioned in the image.\n", + " \"\"\"\n", + "\n", + " model = ChatVertexAI(model_name=MODEL_NAME, max_output_tokens=TOKEN_LIMIT)\n", + "\n", + " # Apply to images\n", + " for img_file in sorted(os.listdir(path)):\n", + " if img_file.endswith(\".png\"):\n", + " base64_image = encode_image(os.path.join(path, img_file))\n", + " img_base64_list.append(base64_image)\n", + " image_summaries.append(image_summarize(model, base64_image, prompt))\n", + "\n", + " return img_base64_list, image_summaries\n", + "\n", + "\n", + "# Image summaries\n", + "img_base64_list, image_summaries = generate_img_summaries(\".\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b641a76265d0" + }, + "source": [ + "## Create & Deploy Vertex AI Vector Search Index & Endpoint\n", + "\n", + "Skip this step if you already have Vector Search set up.\n", + "\n", + "- https://console.cloud.google.com/vertex-ai/matching-engine/indexes" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c15693534ed1" + }, + "source": [ + "- Create [`MatchingEngineIndex`](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.MatchingEngineIndex)\n", + " - https://cloud.google.com/vertex-ai/docs/vector-search/create-manage-index" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dad379accb68" + }, + "outputs": [], + "source": [ + "# https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings\n", + "DIMENSIONS = 768 # Dimensions output from textembedding-gecko\n", + "\n", + "index = aiplatform.MatchingEngineIndex.create_tree_ah_index(\n", + " display_name=\"mm_rag_langchain_index\",\n", + " dimensions=DIMENSIONS,\n", + " approximate_neighbors_count=150,\n", + " leaf_node_embedding_count=500,\n", + " leaf_nodes_to_search_percent=7,\n", + " description=\"Multimodal RAG LangChain Index\",\n", + " index_update_method=\"STREAM_UPDATE\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "add71035aaa1" + }, + "source": [ + "- Create [`MatchingEngineIndexEndpoint`](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.MatchingEngineIndexEndpoint)\n", + " - https://cloud.google.com/vertex-ai/docs/vector-search/deploy-index-public" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "140c0142b90f" + }, + "outputs": [], + "source": [ + "DEPLOYED_INDEX_ID = \"mm_rag_langchain_index_endpoint\"\n", + "\n", + "index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(\n", + " display_name=DEPLOYED_INDEX_ID,\n", + " description=\"Multimodal RAG LangChain Index Endpoint\",\n", + " public_endpoint_enabled=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b6adda75cab6" + }, + "source": [ + "- Deploy Index to Index Endpoint\n", + " - NOTE: This will take a while to run.\n", + " - You can stop this cell after starting it instead of waiting for deployment.\n", + " - You can check the status at https://console.cloud.google.com/vertex-ai/matching-engine/indexes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4a02468a018b" + }, + "outputs": [], + "source": [ + "index_endpoint = index_endpoint.deploy_index(\n", + " index=index, deployed_index_id=\"mm_rag_langchain_deployed_index\"\n", + ")\n", + "index_endpoint.deployed_indexes" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bd8475f61ef9" + }, + "source": [ + "## Create retriever & load documents" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "487ad4e4ccac" + }, + "source": [ + "- Create [`VectorSearchVectorStore`](https://api.python.langchain.com/en/latest/vectorstores/langchain_google_vertexai.vectorstores.vectorstores.VectorSearchVectorStore.html) with Vector Search Index ID and Endpoint ID.\n", + "- Use [`textembedding-gecko`](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings) as embedding model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e49355d04889" + }, + "outputs": [], + "source": [ + "# The vectorstore to use to index the summaries\n", + "vectorstore = VectorSearchVectorStore.from_components(\n", + " project_id=PROJECT_ID,\n", + " region=LOCATION,\n", + " gcs_bucket_name=GCS_BUCKET,\n", + " index_id=index.name,\n", + " endpoint_id=index_endpoint.name,\n", + " embedding=VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME),\n", + " stream_update=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "26ef209ff8ba" + }, + "source": [ + "- Alternatively, use Chroma for a local vector store." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1b7f713b2607" + }, + "outputs": [], + "source": [ + "# vectorstore = Chroma(\n", + "# collection_name=\"mm_rag_test\",\n", + "# embedding_function=VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME),\n", + "# )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "67a4b0490b45" + }, + "source": [ + "- Create Multi-Vector Retriever using the vector store you created.\n", + "- Since vector stores only contain the embedding and an ID, you'll also need to create a document store indexed by ID to get the original source documents after searching for embeddings." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8e92ff890483" + }, + "outputs": [], + "source": [ + "docstore = InMemoryStore()\n", + "\n", + "id_key = \"doc_id\"\n", + "# Create the multi-vector retriever\n", + "retriever_multi_vector_img = MultiVectorRetriever(\n", + " vectorstore=vectorstore,\n", + " docstore=docstore,\n", + " id_key=id_key,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "96b37cf7dc47" + }, + "source": [ + "- Load data into Document Store and Vector Store" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0a92a4b04319" + }, + "outputs": [], + "source": [ + "# Raw Document Contents\n", + "doc_contents = texts + tables + img_base64_list\n", + "\n", + "doc_ids = [str(uuid.uuid4()) for _ in doc_contents]\n", + "summary_docs = [\n", + " Document(page_content=s, metadata={id_key: doc_ids[i]})\n", + " for i, s in enumerate(text_summaries + table_summaries + image_summaries)\n", + "]\n", + "\n", + "retriever_multi_vector_img.docstore.mset(list(zip(doc_ids, doc_contents)))\n", + "\n", + "# If using Vertex AI Vector Search, this will take a while to complete.\n", + "# You can cancel this cell and continue later.\n", + "retriever_multi_vector_img.vectorstore.add_documents(summary_docs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b665ead18f3b" + }, + "source": [ + "## Create Chain with Retriever and Gemini LLM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5228d5831d34" + }, + "outputs": [], + "source": [ + "def looks_like_base64(sb):\n", + " \"\"\"Check if the string looks like base64\"\"\"\n", + " return re.match(\"^[A-Za-z0-9+/]+[=]{0,2}$\", sb) is not None\n", + "\n", + "\n", + "def is_image_data(b64data):\n", + " \"\"\"\n", + " Check if the base64 data is an image by looking at the start of the data\n", + " \"\"\"\n", + " image_signatures = {\n", + " b\"\\xFF\\xD8\\xFF\": \"jpg\",\n", + " b\"\\x89\\x50\\x4E\\x47\\x0D\\x0A\\x1A\\x0A\": \"png\",\n", + " b\"\\x47\\x49\\x46\\x38\": \"gif\",\n", + " b\"\\x52\\x49\\x46\\x46\": \"webp\",\n", + " }\n", + " try:\n", + " header = base64.b64decode(b64data)[:8] # Decode and get the first 8 bytes\n", + " for sig, format in image_signatures.items():\n", + " if header.startswith(sig):\n", + " return True\n", + " return False\n", + " except Exception:\n", + " return False\n", + "\n", + "\n", + "def split_image_text_types(docs):\n", + " \"\"\"\n", + " Split base64-encoded images and texts\n", + " \"\"\"\n", + " b64_images = []\n", + " texts = []\n", + " for doc in docs:\n", + " # Check if the document is of type Document and extract page_content if so\n", + " if isinstance(doc, Document):\n", + " doc = doc.page_content\n", + " if looks_like_base64(doc) and is_image_data(doc):\n", + " b64_images.append(doc)\n", + " else:\n", + " texts.append(doc)\n", + " return {\"images\": b64_images, \"texts\": texts}\n", + "\n", + "\n", + "def img_prompt_func(data_dict):\n", + " \"\"\"\n", + " Join the context into a single string\n", + " \"\"\"\n", + " formatted_texts = \"\\n\".join(data_dict[\"context\"][\"texts\"])\n", + " messages = [\n", + " {\n", + " \"type\": \"text\",\n", + " \"text\": (\n", + " \"You are financial analyst tasking with providing investment advice.\\n\"\n", + " \"You will be given a mix of text, tables, and image(s) usually of charts or graphs.\\n\"\n", + " \"Use this information to provide investment advice related to the user's question. \\n\"\n", + " f\"User-provided question: {data_dict['question']}\\n\\n\"\n", + " \"Text and / or tables:\\n\"\n", + " f\"{formatted_texts}\"\n", + " ),\n", + " }\n", + " ]\n", + "\n", + " # Adding image(s) to the messages if present\n", + " if data_dict[\"context\"][\"images\"]:\n", + " for image in data_dict[\"context\"][\"images\"]:\n", + " messages.append(\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\"url\": f\"data:image/jpeg;base64,{image}\"},\n", + " }\n", + " )\n", + " return [HumanMessage(content=messages)]\n", + "\n", + "\n", + "# Create RAG chain\n", + "chain_multimodal_rag = (\n", + " {\n", + " \"context\": retriever_multi_vector_img | RunnableLambda(split_image_text_types),\n", + " \"question\": RunnablePassthrough(),\n", + " }\n", + " | RunnableLambda(img_prompt_func)\n", + " | ChatVertexAI(\n", + " temperature=0,\n", + " model_name=MODEL_NAME,\n", + " max_output_tokens=TOKEN_LIMIT,\n", + " ) # Multi-modal LLM\n", + " | StrOutputParser()\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2925d397fdbb" + }, + "source": [ + "## Process user query" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "b3b445f934a8" + }, + "outputs": [], + "source": [ + "query = \"\"\"\n", + " - What are the critical difference between various graphs for Class A Share?\n", + " - Which index best matches Class A share performance closely where Google is not already a part? Explain the reasoning.\n", + " - Identify key chart patterns for Google Class A shares.\n", + " - What is cost of revenues, operating expenses and net income for 2020. Do mention the percentage change\n", + " - What was the effect of Covid in the 2020 financial year?\n", + " - What are the total revenues for APAC and USA for 2021?\n", + " - What is deferred income taxes?\n", + " - How do you compute net income per share?\n", + " - What drove percentage change in the consolidated revenue and cost of revenue for the year 2021 and was there any effect of Covid?\n", + " - What is the cause of 41% increase in revenue from 2020 to 2021 and how much is dollar change?\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6172a22b1203" + }, + "source": [ + "### Get Retrieved documents" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "90a34d3712e0" + }, + "outputs": [], + "source": [ + "# List of source documents\n", + "docs = retriever_multi_vector_img.get_relevant_documents(query, limit=10)\n", + "\n", + "source_docs = split_image_text_types(docs)\n", + "\n", + "print(source_docs[\"texts\"])\n", + "\n", + "for i in source_docs[\"images\"]:\n", + " display(Image(base64.b64decode(i)))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bd784ce7f205" + }, + "source": [ + "### Get generative response" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4c5f936f89da" + }, + "outputs": [], + "source": [ + "result = chain_multimodal_rag.invoke(query)\n", + "\n", + "Markdown(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KwNrHCqbi3xi" + }, + "source": [ + "## Conclusions" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "05jynhZnkgxn" + }, + "source": [ + "Congratulations on making it through this multimodal RAG notebook!\n", + "\n", + "While multimodal RAG can be quite powerful, note that it can face some limitations:\n", + "\n", + "* **Data dependency:** Needs high-accuracy data from the text and visuals.\n", + "* **Computationally demanding:** Generating embeddings from multimodal data is resource-intensive.\n", + "* **Domain specific:** Models trained on general data may not shine in specialized fields like medicine.\n", + "* **Black box:** Understanding how these models work can be tricky, hindering trust and adoption.\n", + "\n", + "\n", + "Despite these challenges, multimodal RAG represents a significant step towards search and retrieval systems that can handle diverse, multimodal data." + ] + } + ], + "metadata": { + "colab": { + "name": "multimodal_rag_langchain.ipynb", + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } From 5aaa20ad53a2686d7cf45dd7ac8b100a092e0fb7 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Thu, 31 Oct 2024 13:29:07 -0500 Subject: [PATCH 2/2] Formatting/spelling --- .github/actions/spelling/allow.txt | 1 + .../multimodal_rag_langchain.ipynb | 2018 ++++++++--------- 2 files changed, 1008 insertions(+), 1011 deletions(-) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 37ed7e538d7..2d582babcf9 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -4,6 +4,7 @@ AGG AIP AMNOSH ANZ +APAC APIENTRY APSTUDIO AUVs diff --git a/gemini/use-cases/retrieval-augmented-generation/multimodal_rag_langchain.ipynb b/gemini/use-cases/retrieval-augmented-generation/multimodal_rag_langchain.ipynb index 6d951d184d3..b542ada1269 100644 --- a/gemini/use-cases/retrieval-augmented-generation/multimodal_rag_langchain.ipynb +++ b/gemini/use-cases/retrieval-augmented-generation/multimodal_rag_langchain.ipynb @@ -1,1013 +1,1009 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ijGzTHJJUCPY" - }, - "outputs": [], - "source": [ - "# Copyright 2024 Google LLC\n", - "#\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NDsTUvKjwHBW" - }, - "source": [ - "# Multimodal Retrieval Augmented Generation (RAG) with Gemini, Vertex AI Vector Search, and LangChain\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - "
\n", - " \n", - " \"Google
Run in Colab Enterprise\n", - "
\n", - "
\n", - " \n", - " \"Google
Run in Colab\n", - "
\n", - "
\n", - " \n", - " \"Vertex
Open in Vertex AI Workbench\n", - "
\n", - "
\n", - " \n", - " \"GitHub
View on GitHub\n", - "
\n", - "
" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "37454218170d" - }, - "source": [ - "| | | \n", - "|-|-|\n", - "|Author(s) | [Holt Skinner](https://github.com/holtskinner) |" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VK1Q5ZYdVL4Y" - }, - "source": [ - "## Overview\n", - "\n", - "Retrieval augmented generation (RAG) has become a popular paradigm for enabling LLMs to access external data and also as a mechanism for grounding to mitigate against hallucinations.\n", - "\n", - "In this notebook, you will learn how to perform multimodal RAG where you will perform Q&A over a financial document filled with both text and images.\n", - "\n", - "### Gemini\n", - "\n", - "Gemini is a family of generative AI models developed by Google DeepMind that is designed for multimodal use cases. The Gemini API gives you access to the Gemini 1.0 Pro Vision and Gemini 1.0 Pro models.\n", - "\n", - "### Comparing text-based and multimodal RAG\n", - "\n", - "Multimodal RAG offers several advantages over text-based RAG:\n", - "\n", - "1. **Enhanced knowledge access:** Multimodal RAG can access and process both textual and visual information, providing a richer and more comprehensive knowledge base for the LLM.\n", - "2. **Improved reasoning capabilities:** By incorporating visual cues, multimodal RAG can make better informed inferences across different types of data modalities.\n", - "\n", - "This notebook shows you how to use RAG with Vertex AI Gemini API, and [multimodal embeddings](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/multimodal-embeddings), to build a document search engine.\n", - "\n", - "Through hands-on examples, you will discover how to construct a multimedia-rich metadata repository of your document sources, enabling search, comparison, and reasoning across diverse information streams." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RQT500QqVPIb" - }, - "source": [ - "### Objectives\n", - "\n", - "This notebook provides a guide to building a document search engine using multimodal retrieval augmented generation (RAG), step by step:\n", - "\n", - "1. Extract and store metadata of documents containing both text and images, and generate embeddings the documents\n", - "2. Search the metadata with text queries to find similar text or images\n", - "3. Search the metadata with image queries to find similar images\n", - "4. Using a text query as input, search for contextual answers using both text and images" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KnpYxfesh2rI" - }, - "source": [ - "### Costs\n", - "\n", - "This tutorial uses billable components of Google Cloud:\n", - "\n", - "- Vertex AI\n", - "\n", - "Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DXJpXzKrh2rJ" - }, - "source": [ - "## Getting Started" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "N5afkyDMSBW5" - }, - "source": [ - "### Install Vertex AI SDK for Python and other dependencies" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kc4WxYmLSBW5" - }, - "outputs": [], - "source": [ - "%pip install -U -q google-cloud-aiplatform langchain-core langchain-google-vertexai langchain-text-splitters langchain-community \"unstructured[all-docs]\" pypdf pydantic lxml pillow matplotlib opencv-python tiktoken" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "R5Xep4W9lq-Z" - }, - "source": [ - "### Restart current runtime\n", - "\n", - "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which will restart the current kernel." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XRvKdaPDTznN" - }, - "outputs": [], - "source": [ - "# Restart kernel after installs so that your environment can access the new packages\n", - "import IPython\n", - "\n", - "app = IPython.Application.instance()\n", - "app.kernel.do_shutdown(True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SbmM4z7FOBpM" - }, - "source": [ - "
\n", - "⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️\n", - "
\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FtsU9Bw9h2rL" - }, - "source": [ - "### Authenticate your notebook environment (Colab only)\n", - "\n", - "If you are running this notebook on Google Colab, run the following cell to authenticate your environment. This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "GpYEyLsOh2rL" - }, - "outputs": [], - "source": [ - "import sys\n", - "\n", - "# Additional authentication is required for Google Colab\n", - "if \"google.colab\" in sys.modules:\n", - " # Authenticate user to Google Cloud\n", - " from google.colab import auth\n", - "\n", - " auth.authenticate_user()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O1vKZZoEh2rL" - }, - "source": [ - "### Define Google Cloud project information" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gJqZ76rJh2rM" - }, - "outputs": [], - "source": [ - "PROJECT_ID = \"YOUR_PROJECT_ID\" # @param {type:\"string\"}\n", - "LOCATION = \"us-central1\" # @param {type:\"string\"}\n", - "\n", - "# For Vector Search Staging\n", - "GCS_BUCKET = \"YOUR_BUCKET_NAME\" # @param {type:\"string\"}\n", - "GCS_BUCKET_URI = f\"gs://{GCS_BUCKET}\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "57262621bd1c" - }, - "source": [ - "### Initialize the Vertex AI SDK" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "D48gUW5-h2rM" - }, - "outputs": [], - "source": [ - "from google.cloud import aiplatform\n", - "\n", - "aiplatform.init(project=PROJECT_ID, location=LOCATION, staging_bucket=GCS_BUCKET_URI)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BuQwwRiniVFG" - }, - "source": [ - "### Import libraries" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rtMowvm-yQ97" - }, - "outputs": [], - "source": [ - "import base64\n", - "import os\n", - "import re\n", - "import uuid\n", - "\n", - "from IPython.display import Image, Markdown, display\n", - "from langchain.prompts import PromptTemplate\n", - "from langchain.retrievers.multi_vector import MultiVectorRetriever\n", - "from langchain.storage import InMemoryStore\n", - "from langchain_core.documents import Document\n", - "from langchain_core.messages import AIMessage, HumanMessage\n", - "from langchain_core.output_parsers import StrOutputParser\n", - "from langchain_core.runnables import RunnableLambda, RunnablePassthrough\n", - "from langchain_google_vertexai import (\n", - " ChatVertexAI,\n", - " VectorSearchVectorStore,\n", - " VertexAI,\n", - " VertexAIEmbeddings,\n", - ")\n", - "from langchain_text_splitters import CharacterTextSplitter\n", - "from unstructured.partition.pdf import partition_pdf\n", - "\n", - "# from langchain_community.vectorstores import Chroma # Optional" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2bf3ee5d1686" - }, - "source": [ - "### Define model information\n", - "\n", - "- [Vertex AI - Model Information](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "eb39bdada39d" - }, - "outputs": [], - "source": [ - "MODEL_NAME = \"gemini-1.5-flash\"\n", - "GEMINI_OUTPUT_TOKEN_LIMIT = 8192\n", - "\n", - "EMBEDDING_MODEL_NAME = \"text-embedding-004\"\n", - "EMBEDDING_TOKEN_LIMIT = 2048\n", - "\n", - "TOKEN_LIMIT = min(GEMINI_OUTPUT_TOKEN_LIMIT, EMBEDDING_TOKEN_LIMIT)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2c919bd5a462" - }, - "source": [ - "## Data Loading" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "g7bKCQMFT7JT" - }, - "source": [ - "#### Get documents and images from GCS" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KwbL89zcY39N" - }, - "outputs": [], - "source": [ - "# Download documents and images used in this notebook\n", - "!gsutil -m rsync -r gs://github-repo/rag/intro_multimodal_rag/ .\n", - "print(\"Download completed\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ps1G-cCfpibN" - }, - "source": [ - "## Partition PDF tables, text, and images" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jqLsy3iZ5t-R" - }, - "source": [ - "### The data\n", - "\n", - "The source data that you will use in this notebook is a modified version of [Google-10K](https://abc.xyz/assets/investor/static/pdf/20220202_alphabet_10K.pdf) which provides a comprehensive overview of the company's financial performance, business operations, management, and risk factors. As the original document is rather large, you will be using [a modified version with only 14 pages](https://storage.googleapis.com/github-repo/rag/multimodal_rag_langchain/google-10k-sample-14pages.pdf) instead. Although it's truncated, the sample document still contains text along with images such as tables, charts, and graphs." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3a87cb1a097b" - }, - "outputs": [], - "source": [ - "pdf_folder_path = \"/content/data/\" if \"google.colab\" in sys.modules else \"data/\"\n", - "pdf_file_name = \"google-10k-sample-14pages.pdf\"\n", - "\n", - "# Extract images, tables, and chunk text from a PDF file.\n", - "raw_pdf_elements = partition_pdf(\n", - " filename=pdf_file_name,\n", - " extract_images_in_pdf=False,\n", - " infer_table_structure=True,\n", - " chunking_strategy=\"by_title\",\n", - " max_characters=4000,\n", - " new_after_n_chars=3800,\n", - " combine_text_under_n_chars=2000,\n", - " image_output_dir_path=pdf_folder_path,\n", - ")\n", - "\n", - "# Categorize extracted elements from a PDF into tables and texts.\n", - "tables = []\n", - "texts = []\n", - "for element in raw_pdf_elements:\n", - " if \"unstructured.documents.elements.Table\" in str(type(element)):\n", - " tables.append(str(element))\n", - " elif \"unstructured.documents.elements.CompositeElement\" in str(type(element)):\n", - " texts.append(str(element))\n", - "\n", - "# Optional: Enforce a specific token size for texts\n", - "text_splitter = CharacterTextSplitter.from_tiktoken_encoder(\n", - " chunk_size=10000, chunk_overlap=0\n", - ")\n", - "joined_texts = \" \".join(texts)\n", - "texts_4k_token = text_splitter.split_text(joined_texts)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "244963a30836" - }, - "outputs": [], - "source": [ - "# Generate summaries of text elements\n", - "\n", - "\n", - "def generate_text_summaries(\n", - " texts: list[str], tables: list[str], summarize_texts: bool = False\n", - ") -> tuple[list, list]:\n", - " \"\"\"\n", - " Summarize text elements\n", - " texts: List of str\n", - " tables: List of str\n", - " summarize_texts: Bool to summarize texts\n", - " \"\"\"\n", - "\n", - " # Prompt\n", - " prompt_text = \"\"\"You are an assistant tasked with summarizing tables and text for retrieval. \\\n", - " These summaries will be embedded and used to retrieve the raw text or table elements. \\\n", - " Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element} \"\"\"\n", - " prompt = PromptTemplate.from_template(prompt_text)\n", - " empty_response = RunnableLambda(\n", - " lambda x: AIMessage(content=\"Error processing document\")\n", - " )\n", - " # Text summary chain\n", - " model = VertexAI(\n", - " temperature=0, model_name=MODEL_NAME, max_output_tokens=TOKEN_LIMIT\n", - " ).with_fallbacks([empty_response])\n", - " summarize_chain = {\"element\": lambda x: x} | prompt | model | StrOutputParser()\n", - "\n", - " # Initialize empty summaries\n", - " text_summaries = []\n", - " table_summaries = []\n", - "\n", - " # Apply to text if texts are provided and summarization is requested\n", - " if texts:\n", - " if summarize_texts:\n", - " text_summaries = summarize_chain.batch(texts, {\"max_concurrency\": 1})\n", - " else:\n", - " text_summaries = texts\n", - "\n", - " # Apply to tables if tables are provided\n", - " if tables:\n", - " table_summaries = summarize_chain.batch(tables, {\"max_concurrency\": 1})\n", - "\n", - " return text_summaries, table_summaries\n", - "\n", - "\n", - "# Get text, table summaries\n", - "text_summaries, table_summaries = generate_text_summaries(\n", - " texts_4k_token, tables, summarize_texts=True\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "379ae4ffbf83" - }, - "outputs": [], - "source": [ - "def encode_image(image_path: str) -> str:\n", - " \"\"\"Getting the base64 string\"\"\"\n", - " with open(image_path, \"rb\") as image_file:\n", - " return base64.b64encode(image_file.read()).decode(\"utf-8\")\n", - "\n", - "\n", - "def image_summarize(model: ChatVertexAI, base64_image: str, prompt: str) -> str:\n", - " \"\"\"Make image summary\"\"\"\n", - " msg = model.invoke(\n", - " [\n", - " HumanMessage(\n", - " content=[\n", - " {\"type\": \"text\", \"text\": prompt},\n", - " {\n", - " \"type\": \"image_url\",\n", - " \"image_url\": {\"url\": f\"data:image/png;base64,{base64_image}\"},\n", - " },\n", - " ]\n", - " )\n", - " ]\n", - " )\n", - " return msg.content\n", - "\n", - "\n", - "def generate_img_summaries(path: str) -> tuple[list[str], list[str]]:\n", - " \"\"\"\n", - " Generate summaries and base64 encoded strings for images\n", - " path: Path to list of .jpg files extracted by Unstructured\n", - " \"\"\"\n", - "\n", - " # Store base64 encoded images\n", - " img_base64_list = []\n", - "\n", - " # Store image summaries\n", - " image_summaries = []\n", - "\n", - " # Prompt\n", - " prompt = \"\"\"You are an assistant tasked with summarizing images for retrieval. \\\n", - " These summaries will be embedded and used to retrieve the raw image. \\\n", - " Give a concise summary of the image that is well optimized for retrieval.\n", - " If it's a table, extract all elements of the table.\n", - " If it's a graph, explain the findings in the graph.\n", - " Do not include any numbers that are not mentioned in the image.\n", - " \"\"\"\n", - "\n", - " model = ChatVertexAI(model_name=MODEL_NAME, max_output_tokens=TOKEN_LIMIT)\n", - "\n", - " # Apply to images\n", - " for img_file in sorted(os.listdir(path)):\n", - " if img_file.endswith(\".png\"):\n", - " base64_image = encode_image(os.path.join(path, img_file))\n", - " img_base64_list.append(base64_image)\n", - " image_summaries.append(image_summarize(model, base64_image, prompt))\n", - "\n", - " return img_base64_list, image_summaries\n", - "\n", - "\n", - "# Image summaries\n", - "img_base64_list, image_summaries = generate_img_summaries(\".\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b641a76265d0" - }, - "source": [ - "## Create & Deploy Vertex AI Vector Search Index & Endpoint\n", - "\n", - "Skip this step if you already have Vector Search set up.\n", - "\n", - "- https://console.cloud.google.com/vertex-ai/matching-engine/indexes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "c15693534ed1" - }, - "source": [ - "- Create [`MatchingEngineIndex`](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.MatchingEngineIndex)\n", - " - https://cloud.google.com/vertex-ai/docs/vector-search/create-manage-index" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "dad379accb68" - }, - "outputs": [], - "source": [ - "# https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings\n", - "DIMENSIONS = 768 # Dimensions output from textembedding-gecko\n", - "\n", - "index = aiplatform.MatchingEngineIndex.create_tree_ah_index(\n", - " display_name=\"mm_rag_langchain_index\",\n", - " dimensions=DIMENSIONS,\n", - " approximate_neighbors_count=150,\n", - " leaf_node_embedding_count=500,\n", - " leaf_nodes_to_search_percent=7,\n", - " description=\"Multimodal RAG LangChain Index\",\n", - " index_update_method=\"STREAM_UPDATE\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "add71035aaa1" - }, - "source": [ - "- Create [`MatchingEngineIndexEndpoint`](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.MatchingEngineIndexEndpoint)\n", - " - https://cloud.google.com/vertex-ai/docs/vector-search/deploy-index-public" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "140c0142b90f" - }, - "outputs": [], - "source": [ - "DEPLOYED_INDEX_ID = \"mm_rag_langchain_index_endpoint\"\n", - "\n", - "index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(\n", - " display_name=DEPLOYED_INDEX_ID,\n", - " description=\"Multimodal RAG LangChain Index Endpoint\",\n", - " public_endpoint_enabled=True,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b6adda75cab6" - }, - "source": [ - "- Deploy Index to Index Endpoint\n", - " - NOTE: This will take a while to run.\n", - " - You can stop this cell after starting it instead of waiting for deployment.\n", - " - You can check the status at https://console.cloud.google.com/vertex-ai/matching-engine/indexes" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4a02468a018b" - }, - "outputs": [], - "source": [ - "index_endpoint = index_endpoint.deploy_index(\n", - " index=index, deployed_index_id=\"mm_rag_langchain_deployed_index\"\n", - ")\n", - "index_endpoint.deployed_indexes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bd8475f61ef9" - }, - "source": [ - "## Create retriever & load documents" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "487ad4e4ccac" - }, - "source": [ - "- Create [`VectorSearchVectorStore`](https://api.python.langchain.com/en/latest/vectorstores/langchain_google_vertexai.vectorstores.vectorstores.VectorSearchVectorStore.html) with Vector Search Index ID and Endpoint ID.\n", - "- Use [`textembedding-gecko`](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings) as embedding model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "e49355d04889" - }, - "outputs": [], - "source": [ - "# The vectorstore to use to index the summaries\n", - "vectorstore = VectorSearchVectorStore.from_components(\n", - " project_id=PROJECT_ID,\n", - " region=LOCATION,\n", - " gcs_bucket_name=GCS_BUCKET,\n", - " index_id=index.name,\n", - " endpoint_id=index_endpoint.name,\n", - " embedding=VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME),\n", - " stream_update=True,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "26ef209ff8ba" - }, - "source": [ - "- Alternatively, use Chroma for a local vector store." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "1b7f713b2607" - }, - "outputs": [], - "source": [ - "# vectorstore = Chroma(\n", - "# collection_name=\"mm_rag_test\",\n", - "# embedding_function=VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME),\n", - "# )" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "67a4b0490b45" - }, - "source": [ - "- Create Multi-Vector Retriever using the vector store you created.\n", - "- Since vector stores only contain the embedding and an ID, you'll also need to create a document store indexed by ID to get the original source documents after searching for embeddings." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8e92ff890483" - }, - "outputs": [], - "source": [ - "docstore = InMemoryStore()\n", - "\n", - "id_key = \"doc_id\"\n", - "# Create the multi-vector retriever\n", - "retriever_multi_vector_img = MultiVectorRetriever(\n", - " vectorstore=vectorstore,\n", - " docstore=docstore,\n", - " id_key=id_key,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "96b37cf7dc47" - }, - "source": [ - "- Load data into Document Store and Vector Store" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0a92a4b04319" - }, - "outputs": [], - "source": [ - "# Raw Document Contents\n", - "doc_contents = texts + tables + img_base64_list\n", - "\n", - "doc_ids = [str(uuid.uuid4()) for _ in doc_contents]\n", - "summary_docs = [\n", - " Document(page_content=s, metadata={id_key: doc_ids[i]})\n", - " for i, s in enumerate(text_summaries + table_summaries + image_summaries)\n", - "]\n", - "\n", - "retriever_multi_vector_img.docstore.mset(list(zip(doc_ids, doc_contents)))\n", - "\n", - "# If using Vertex AI Vector Search, this will take a while to complete.\n", - "# You can cancel this cell and continue later.\n", - "retriever_multi_vector_img.vectorstore.add_documents(summary_docs)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b665ead18f3b" - }, - "source": [ - "## Create Chain with Retriever and Gemini LLM" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "5228d5831d34" - }, - "outputs": [], - "source": [ - "def looks_like_base64(sb):\n", - " \"\"\"Check if the string looks like base64\"\"\"\n", - " return re.match(\"^[A-Za-z0-9+/]+[=]{0,2}$\", sb) is not None\n", - "\n", - "\n", - "def is_image_data(b64data):\n", - " \"\"\"\n", - " Check if the base64 data is an image by looking at the start of the data\n", - " \"\"\"\n", - " image_signatures = {\n", - " b\"\\xFF\\xD8\\xFF\": \"jpg\",\n", - " b\"\\x89\\x50\\x4E\\x47\\x0D\\x0A\\x1A\\x0A\": \"png\",\n", - " b\"\\x47\\x49\\x46\\x38\": \"gif\",\n", - " b\"\\x52\\x49\\x46\\x46\": \"webp\",\n", - " }\n", - " try:\n", - " header = base64.b64decode(b64data)[:8] # Decode and get the first 8 bytes\n", - " for sig, format in image_signatures.items():\n", - " if header.startswith(sig):\n", - " return True\n", - " return False\n", - " except Exception:\n", - " return False\n", - "\n", - "\n", - "def split_image_text_types(docs):\n", - " \"\"\"\n", - " Split base64-encoded images and texts\n", - " \"\"\"\n", - " b64_images = []\n", - " texts = []\n", - " for doc in docs:\n", - " # Check if the document is of type Document and extract page_content if so\n", - " if isinstance(doc, Document):\n", - " doc = doc.page_content\n", - " if looks_like_base64(doc) and is_image_data(doc):\n", - " b64_images.append(doc)\n", - " else:\n", - " texts.append(doc)\n", - " return {\"images\": b64_images, \"texts\": texts}\n", - "\n", - "\n", - "def img_prompt_func(data_dict):\n", - " \"\"\"\n", - " Join the context into a single string\n", - " \"\"\"\n", - " formatted_texts = \"\\n\".join(data_dict[\"context\"][\"texts\"])\n", - " messages = [\n", - " {\n", - " \"type\": \"text\",\n", - " \"text\": (\n", - " \"You are financial analyst tasking with providing investment advice.\\n\"\n", - " \"You will be given a mix of text, tables, and image(s) usually of charts or graphs.\\n\"\n", - " \"Use this information to provide investment advice related to the user's question. \\n\"\n", - " f\"User-provided question: {data_dict['question']}\\n\\n\"\n", - " \"Text and / or tables:\\n\"\n", - " f\"{formatted_texts}\"\n", - " ),\n", - " }\n", - " ]\n", - "\n", - " # Adding image(s) to the messages if present\n", - " if data_dict[\"context\"][\"images\"]:\n", - " for image in data_dict[\"context\"][\"images\"]:\n", - " messages.append(\n", - " {\n", - " \"type\": \"image_url\",\n", - " \"image_url\": {\"url\": f\"data:image/jpeg;base64,{image}\"},\n", - " }\n", - " )\n", - " return [HumanMessage(content=messages)]\n", - "\n", - "\n", - "# Create RAG chain\n", - "chain_multimodal_rag = (\n", - " {\n", - " \"context\": retriever_multi_vector_img | RunnableLambda(split_image_text_types),\n", - " \"question\": RunnablePassthrough(),\n", - " }\n", - " | RunnableLambda(img_prompt_func)\n", - " | ChatVertexAI(\n", - " temperature=0,\n", - " model_name=MODEL_NAME,\n", - " max_output_tokens=TOKEN_LIMIT,\n", - " ) # Multi-modal LLM\n", - " | StrOutputParser()\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2925d397fdbb" - }, - "source": [ - "## Process user query" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "b3b445f934a8" - }, - "outputs": [], - "source": [ - "query = \"\"\"\n", - " - What are the critical difference between various graphs for Class A Share?\n", - " - Which index best matches Class A share performance closely where Google is not already a part? Explain the reasoning.\n", - " - Identify key chart patterns for Google Class A shares.\n", - " - What is cost of revenues, operating expenses and net income for 2020. Do mention the percentage change\n", - " - What was the effect of Covid in the 2020 financial year?\n", - " - What are the total revenues for APAC and USA for 2021?\n", - " - What is deferred income taxes?\n", - " - How do you compute net income per share?\n", - " - What drove percentage change in the consolidated revenue and cost of revenue for the year 2021 and was there any effect of Covid?\n", - " - What is the cause of 41% increase in revenue from 2020 to 2021 and how much is dollar change?\n", - "\"\"\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6172a22b1203" - }, - "source": [ - "### Get Retrieved documents" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "90a34d3712e0" - }, - "outputs": [], - "source": [ - "# List of source documents\n", - "docs = retriever_multi_vector_img.get_relevant_documents(query, limit=10)\n", - "\n", - "source_docs = split_image_text_types(docs)\n", - "\n", - "print(source_docs[\"texts\"])\n", - "\n", - "for i in source_docs[\"images\"]:\n", - " display(Image(base64.b64decode(i)))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bd784ce7f205" - }, - "source": [ - "### Get generative response" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4c5f936f89da" - }, - "outputs": [], - "source": [ - "result = chain_multimodal_rag.invoke(query)\n", - "\n", - "Markdown(result)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KwNrHCqbi3xi" - }, - "source": [ - "## Conclusions" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "05jynhZnkgxn" - }, - "source": [ - "Congratulations on making it through this multimodal RAG notebook!\n", - "\n", - "While multimodal RAG can be quite powerful, note that it can face some limitations:\n", - "\n", - "* **Data dependency:** Needs high-accuracy data from the text and visuals.\n", - "* **Computationally demanding:** Generating embeddings from multimodal data is resource-intensive.\n", - "* **Domain specific:** Models trained on general data may not shine in specialized fields like medicine.\n", - "* **Black box:** Understanding how these models work can be tricky, hindering trust and adoption.\n", - "\n", - "\n", - "Despite these challenges, multimodal RAG represents a significant step towards search and retrieval systems that can handle diverse, multimodal data." - ] - } - ], - "metadata": { - "colab": { - "name": "multimodal_rag_langchain.ipynb", - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.11.0" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ijGzTHJJUCPY" + }, + "outputs": [], + "source": [ + "# Copyright 2024 Google LLC\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NDsTUvKjwHBW" + }, + "source": [ + "# Multimodal Retrieval Augmented Generation (RAG) with Gemini, Vertex AI Vector Search, and LangChain\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \"Google
Run in Colab Enterprise\n", + "
\n", + "
\n", + " \n", + " \"Google
Run in Colab\n", + "
\n", + "
\n", + " \n", + " \"Vertex
Open in Vertex AI Workbench\n", + "
\n", + "
\n", + " \n", + " \"GitHub
View on GitHub\n", + "
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "37454218170d" + }, + "source": [ + "| | | \n", + "|-|-|\n", + "|Author(s) | [Holt Skinner](https://github.com/holtskinner) |" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VK1Q5ZYdVL4Y" + }, + "source": [ + "## Overview\n", + "\n", + "Retrieval augmented generation (RAG) has become a popular paradigm for enabling LLMs to access external data and also as a mechanism for grounding to mitigate against hallucinations.\n", + "\n", + "In this notebook, you will learn how to perform multimodal RAG where you will perform Q&A over a financial document filled with both text and images.\n", + "\n", + "### Gemini\n", + "\n", + "Gemini is a family of generative AI models developed by Google DeepMind that is designed for multimodal use cases. The Gemini API gives you access to the Gemini 1.0 Pro Vision and Gemini 1.0 Pro models.\n", + "\n", + "### Comparing text-based and multimodal RAG\n", + "\n", + "Multimodal RAG offers several advantages over text-based RAG:\n", + "\n", + "1. **Enhanced knowledge access:** Multimodal RAG can access and process both textual and visual information, providing a richer and more comprehensive knowledge base for the LLM.\n", + "2. **Improved reasoning capabilities:** By incorporating visual cues, multimodal RAG can make better informed inferences across different types of data modalities.\n", + "\n", + "This notebook shows you how to use RAG with Vertex AI Gemini API, and [multimodal embeddings](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/multimodal-embeddings), to build a document search engine.\n", + "\n", + "Through hands-on examples, you will discover how to construct a multimedia-rich metadata repository of your document sources, enabling search, comparison, and reasoning across diverse information streams." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RQT500QqVPIb" + }, + "source": [ + "### Objectives\n", + "\n", + "This notebook provides a guide to building a document search engine using multimodal retrieval augmented generation (RAG), step by step:\n", + "\n", + "1. Extract and store metadata of documents containing both text and images, and generate embeddings the documents\n", + "2. Search the metadata with text queries to find similar text or images\n", + "3. Search the metadata with image queries to find similar images\n", + "4. Using a text query as input, search for contextual answers using both text and images" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KnpYxfesh2rI" + }, + "source": [ + "### Costs\n", + "\n", + "This tutorial uses billable components of Google Cloud:\n", + "\n", + "- Vertex AI\n", + "\n", + "Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DXJpXzKrh2rJ" + }, + "source": [ + "## Getting Started" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N5afkyDMSBW5" + }, + "source": [ + "### Install Vertex AI SDK for Python and other dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kc4WxYmLSBW5" + }, + "outputs": [], + "source": [ + "%pip install -U -q google-cloud-aiplatform langchain-core langchain-google-vertexai langchain-text-splitters langchain-community \"unstructured[all-docs]\" pypdf pydantic lxml pillow matplotlib opencv-python tiktoken" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R5Xep4W9lq-Z" + }, + "source": [ + "### Restart current runtime\n", + "\n", + "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which will restart the current kernel." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XRvKdaPDTznN" + }, + "outputs": [], + "source": [ + "# Restart kernel after installs so that your environment can access the new packages\n", + "import IPython\n", + "\n", + "app = IPython.Application.instance()\n", + "app.kernel.do_shutdown(True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SbmM4z7FOBpM" + }, + "source": [ + "
\n", + "⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FtsU9Bw9h2rL" + }, + "source": [ + "### Authenticate your notebook environment (Colab only)\n", + "\n", + "If you are running this notebook on Google Colab, run the following cell to authenticate your environment. This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GpYEyLsOh2rL" + }, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "# Additional authentication is required for Google Colab\n", + "if \"google.colab\" in sys.modules:\n", + " # Authenticate user to Google Cloud\n", + " from google.colab import auth\n", + "\n", + " auth.authenticate_user()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O1vKZZoEh2rL" + }, + "source": [ + "### Define Google Cloud project information" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gJqZ76rJh2rM" + }, + "outputs": [], + "source": [ + "PROJECT_ID = \"YOUR_PROJECT_ID\" # @param {type:\"string\"}\n", + "LOCATION = \"us-central1\" # @param {type:\"string\"}\n", + "\n", + "# For Vector Search Staging\n", + "GCS_BUCKET = \"YOUR_BUCKET_NAME\" # @param {type:\"string\"}\n", + "GCS_BUCKET_URI = f\"gs://{GCS_BUCKET}\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "57262621bd1c" + }, + "source": [ + "### Initialize the Vertex AI SDK" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "D48gUW5-h2rM" + }, + "outputs": [], + "source": [ + "from google.cloud import aiplatform\n", + "\n", + "aiplatform.init(project=PROJECT_ID, location=LOCATION, staging_bucket=GCS_BUCKET_URI)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BuQwwRiniVFG" + }, + "source": [ + "### Import libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rtMowvm-yQ97" + }, + "outputs": [], + "source": [ + "import base64\n", + "import os\n", + "import re\n", + "import uuid\n", + "\n", + "from IPython.display import Image, Markdown, display\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain.retrievers.multi_vector import MultiVectorRetriever\n", + "from langchain.storage import InMemoryStore\n", + "from langchain_core.documents import Document\n", + "from langchain_core.messages import AIMessage, HumanMessage\n", + "from langchain_core.output_parsers import StrOutputParser\n", + "from langchain_core.runnables import RunnableLambda, RunnablePassthrough\n", + "from langchain_google_vertexai import (\n", + " ChatVertexAI,\n", + " VectorSearchVectorStore,\n", + " VertexAI,\n", + " VertexAIEmbeddings,\n", + ")\n", + "from langchain_text_splitters import CharacterTextSplitter\n", + "from unstructured.partition.pdf import partition_pdf\n", + "\n", + "# from langchain_community.vectorstores import Chroma # Optional" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2bf3ee5d1686" + }, + "source": [ + "### Define model information\n", + "\n", + "- [Vertex AI - Model Information](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "eb39bdada39d" + }, + "outputs": [], + "source": [ + "MODEL_NAME = \"gemini-1.5-flash\"\n", + "GEMINI_OUTPUT_TOKEN_LIMIT = 8192\n", + "\n", + "EMBEDDING_MODEL_NAME = \"text-embedding-004\"\n", + "EMBEDDING_TOKEN_LIMIT = 2048\n", + "\n", + "TOKEN_LIMIT = min(GEMINI_OUTPUT_TOKEN_LIMIT, EMBEDDING_TOKEN_LIMIT)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2c919bd5a462" + }, + "source": [ + "## Data Loading" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "g7bKCQMFT7JT" + }, + "source": [ + "#### Get documents and images from GCS" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KwbL89zcY39N" + }, + "outputs": [], + "source": [ + "# Download documents and images used in this notebook\n", + "!gsutil -m rsync -r gs://github-repo/rag/intro_multimodal_rag/ .\n", + "print(\"Download completed\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ps1G-cCfpibN" + }, + "source": [ + "## Partition PDF tables, text, and images" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jqLsy3iZ5t-R" + }, + "source": [ + "### The data\n", + "\n", + "The source data that you will use in this notebook is a modified version of [Google-10K](https://abc.xyz/assets/investor/static/pdf/20220202_alphabet_10K.pdf) which provides a comprehensive overview of the company's financial performance, business operations, management, and risk factors. As the original document is rather large, you will be using [a modified version with only 14 pages](https://storage.googleapis.com/github-repo/rag/multimodal_rag_langchain/google-10k-sample-14pages.pdf) instead. Although it's truncated, the sample document still contains text along with images such as tables, charts, and graphs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3a87cb1a097b" + }, + "outputs": [], + "source": [ + "pdf_folder_path = \"/content/data/\" if \"google.colab\" in sys.modules else \"data/\"\n", + "pdf_file_name = \"google-10k-sample-14pages.pdf\"\n", + "\n", + "# Extract images, tables, and chunk text from a PDF file.\n", + "raw_pdf_elements = partition_pdf(\n", + " filename=pdf_file_name,\n", + " extract_images_in_pdf=False,\n", + " infer_table_structure=True,\n", + " chunking_strategy=\"by_title\",\n", + " max_characters=4000,\n", + " new_after_n_chars=3800,\n", + " combine_text_under_n_chars=2000,\n", + " image_output_dir_path=pdf_folder_path,\n", + ")\n", + "\n", + "# Categorize extracted elements from a PDF into tables and texts.\n", + "tables = []\n", + "texts = []\n", + "for element in raw_pdf_elements:\n", + " if \"unstructured.documents.elements.Table\" in str(type(element)):\n", + " tables.append(str(element))\n", + " elif \"unstructured.documents.elements.CompositeElement\" in str(type(element)):\n", + " texts.append(str(element))\n", + "\n", + "# Optional: Enforce a specific token size for texts\n", + "text_splitter = CharacterTextSplitter.from_tiktoken_encoder(\n", + " chunk_size=10000, chunk_overlap=0\n", + ")\n", + "joined_texts = \" \".join(texts)\n", + "texts_4k_token = text_splitter.split_text(joined_texts)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "244963a30836" + }, + "outputs": [], + "source": [ + "# Generate summaries of text elements\n", + "\n", + "\n", + "def generate_text_summaries(\n", + " texts: list[str], tables: list[str], summarize_texts: bool = False\n", + ") -> tuple[list, list]:\n", + " \"\"\"\n", + " Summarize text elements\n", + " texts: List of str\n", + " tables: List of str\n", + " summarize_texts: Bool to summarize texts\n", + " \"\"\"\n", + "\n", + " # Prompt\n", + " prompt_text = \"\"\"You are an assistant tasked with summarizing tables and text for retrieval. \\\n", + " These summaries will be embedded and used to retrieve the raw text or table elements. \\\n", + " Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element} \"\"\"\n", + " prompt = PromptTemplate.from_template(prompt_text)\n", + " empty_response = RunnableLambda(\n", + " lambda x: AIMessage(content=\"Error processing document\")\n", + " )\n", + " # Text summary chain\n", + " model = VertexAI(\n", + " temperature=0, model_name=MODEL_NAME, max_output_tokens=TOKEN_LIMIT\n", + " ).with_fallbacks([empty_response])\n", + " summarize_chain = {\"element\": lambda x: x} | prompt | model | StrOutputParser()\n", + "\n", + " # Initialize empty summaries\n", + " text_summaries = []\n", + " table_summaries = []\n", + "\n", + " # Apply to text if texts are provided and summarization is requested\n", + " if texts:\n", + " if summarize_texts:\n", + " text_summaries = summarize_chain.batch(texts, {\"max_concurrency\": 1})\n", + " else:\n", + " text_summaries = texts\n", + "\n", + " # Apply to tables if tables are provided\n", + " if tables:\n", + " table_summaries = summarize_chain.batch(tables, {\"max_concurrency\": 1})\n", + "\n", + " return text_summaries, table_summaries\n", + "\n", + "\n", + "# Get text, table summaries\n", + "text_summaries, table_summaries = generate_text_summaries(\n", + " texts_4k_token, tables, summarize_texts=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "379ae4ffbf83" + }, + "outputs": [], + "source": [ + "def encode_image(image_path: str) -> str:\n", + " \"\"\"Getting the base64 string\"\"\"\n", + " with open(image_path, \"rb\") as image_file:\n", + " return base64.b64encode(image_file.read()).decode(\"utf-8\")\n", + "\n", + "\n", + "def image_summarize(model: ChatVertexAI, base64_image: str, prompt: str) -> str:\n", + " \"\"\"Make image summary\"\"\"\n", + " msg = model.invoke(\n", + " [\n", + " HumanMessage(\n", + " content=[\n", + " {\"type\": \"text\", \"text\": prompt},\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\"url\": f\"data:image/png;base64,{base64_image}\"},\n", + " },\n", + " ]\n", + " )\n", + " ]\n", + " )\n", + " return msg.content\n", + "\n", + "\n", + "def generate_img_summaries(path: str) -> tuple[list[str], list[str]]:\n", + " \"\"\"\n", + " Generate summaries and base64 encoded strings for images\n", + " path: Path to list of .jpg files extracted by Unstructured\n", + " \"\"\"\n", + "\n", + " # Store base64 encoded images\n", + " img_base64_list = []\n", + "\n", + " # Store image summaries\n", + " image_summaries = []\n", + "\n", + " # Prompt\n", + " prompt = \"\"\"You are an assistant tasked with summarizing images for retrieval. \\\n", + " These summaries will be embedded and used to retrieve the raw image. \\\n", + " Give a concise summary of the image that is well optimized for retrieval.\n", + " If it's a table, extract all elements of the table.\n", + " If it's a graph, explain the findings in the graph.\n", + " Do not include any numbers that are not mentioned in the image.\n", + " \"\"\"\n", + "\n", + " model = ChatVertexAI(model_name=MODEL_NAME, max_output_tokens=TOKEN_LIMIT)\n", + "\n", + " # Apply to images\n", + " for img_file in sorted(os.listdir(path)):\n", + " if img_file.endswith(\".png\"):\n", + " base64_image = encode_image(os.path.join(path, img_file))\n", + " img_base64_list.append(base64_image)\n", + " image_summaries.append(image_summarize(model, base64_image, prompt))\n", + "\n", + " return img_base64_list, image_summaries\n", + "\n", + "\n", + "# Image summaries\n", + "img_base64_list, image_summaries = generate_img_summaries(\".\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b641a76265d0" + }, + "source": [ + "## Create & Deploy Vertex AI Vector Search Index & Endpoint\n", + "\n", + "Skip this step if you already have Vector Search set up.\n", + "\n", + "- https://console.cloud.google.com/vertex-ai/matching-engine/indexes" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c15693534ed1" + }, + "source": [ + "- Create [`MatchingEngineIndex`](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.MatchingEngineIndex)\n", + " - https://cloud.google.com/vertex-ai/docs/vector-search/create-manage-index" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dad379accb68" + }, + "outputs": [], + "source": [ + "# https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings\n", + "DIMENSIONS = 768 # Dimensions output from textembedding-gecko\n", + "\n", + "index = aiplatform.MatchingEngineIndex.create_tree_ah_index(\n", + " display_name=\"mm_rag_langchain_index\",\n", + " dimensions=DIMENSIONS,\n", + " approximate_neighbors_count=150,\n", + " leaf_node_embedding_count=500,\n", + " leaf_nodes_to_search_percent=7,\n", + " description=\"Multimodal RAG LangChain Index\",\n", + " index_update_method=\"STREAM_UPDATE\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "add71035aaa1" + }, + "source": [ + "- Create [`MatchingEngineIndexEndpoint`](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.MatchingEngineIndexEndpoint)\n", + " - https://cloud.google.com/vertex-ai/docs/vector-search/deploy-index-public" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "140c0142b90f" + }, + "outputs": [], + "source": [ + "DEPLOYED_INDEX_ID = \"mm_rag_langchain_index_endpoint\"\n", + "\n", + "index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(\n", + " display_name=DEPLOYED_INDEX_ID,\n", + " description=\"Multimodal RAG LangChain Index Endpoint\",\n", + " public_endpoint_enabled=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b6adda75cab6" + }, + "source": [ + "- Deploy Index to Index Endpoint\n", + " - NOTE: This will take a while to run.\n", + " - You can stop this cell after starting it instead of waiting for deployment.\n", + " - You can check the status at https://console.cloud.google.com/vertex-ai/matching-engine/indexes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4a02468a018b" + }, + "outputs": [], + "source": [ + "index_endpoint = index_endpoint.deploy_index(\n", + " index=index, deployed_index_id=\"mm_rag_langchain_deployed_index\"\n", + ")\n", + "index_endpoint.deployed_indexes" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bd8475f61ef9" + }, + "source": [ + "## Create retriever & load documents" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "487ad4e4ccac" + }, + "source": [ + "- Create [`VectorSearchVectorStore`](https://api.python.langchain.com/en/latest/vectorstores/langchain_google_vertexai.vectorstores.vectorstores.VectorSearchVectorStore.html) with Vector Search Index ID and Endpoint ID.\n", + "- Use [`textembedding-gecko`](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings) as embedding model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e49355d04889" + }, + "outputs": [], + "source": [ + "# The vectorstore to use to index the summaries\n", + "vectorstore = VectorSearchVectorStore.from_components(\n", + " project_id=PROJECT_ID,\n", + " region=LOCATION,\n", + " gcs_bucket_name=GCS_BUCKET,\n", + " index_id=index.name,\n", + " endpoint_id=index_endpoint.name,\n", + " embedding=VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME),\n", + " stream_update=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "26ef209ff8ba" + }, + "source": [ + "- Alternatively, use Chroma for a local vector store." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1b7f713b2607" + }, + "outputs": [], + "source": [ + "# vectorstore = Chroma(\n", + "# collection_name=\"mm_rag_test\",\n", + "# embedding_function=VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME),\n", + "# )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "67a4b0490b45" + }, + "source": [ + "- Create Multi-Vector Retriever using the vector store you created.\n", + "- Since vector stores only contain the embedding and an ID, you'll also need to create a document store indexed by ID to get the original source documents after searching for embeddings." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8e92ff890483" + }, + "outputs": [], + "source": [ + "docstore = InMemoryStore()\n", + "\n", + "id_key = \"doc_id\"\n", + "# Create the multi-vector retriever\n", + "retriever_multi_vector_img = MultiVectorRetriever(\n", + " vectorstore=vectorstore,\n", + " docstore=docstore,\n", + " id_key=id_key,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "96b37cf7dc47" + }, + "source": [ + "- Load data into Document Store and Vector Store" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0a92a4b04319" + }, + "outputs": [], + "source": [ + "# Raw Document Contents\n", + "doc_contents = texts + tables + img_base64_list\n", + "\n", + "doc_ids = [str(uuid.uuid4()) for _ in doc_contents]\n", + "summary_docs = [\n", + " Document(page_content=s, metadata={id_key: doc_ids[i]})\n", + " for i, s in enumerate(text_summaries + table_summaries + image_summaries)\n", + "]\n", + "\n", + "retriever_multi_vector_img.docstore.mset(list(zip(doc_ids, doc_contents)))\n", + "\n", + "# If using Vertex AI Vector Search, this will take a while to complete.\n", + "# You can cancel this cell and continue later.\n", + "retriever_multi_vector_img.vectorstore.add_documents(summary_docs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b665ead18f3b" + }, + "source": [ + "## Create Chain with Retriever and Gemini LLM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5228d5831d34" + }, + "outputs": [], + "source": [ + "def looks_like_base64(sb):\n", + " \"\"\"Check if the string looks like base64\"\"\"\n", + " return re.match(\"^[A-Za-z0-9+/]+[=]{0,2}$\", sb) is not None\n", + "\n", + "\n", + "def is_image_data(b64data):\n", + " \"\"\"\n", + " Check if the base64 data is an image by looking at the start of the data\n", + " \"\"\"\n", + " image_signatures = {\n", + " b\"\\xFF\\xD8\\xFF\": \"jpg\",\n", + " b\"\\x89\\x50\\x4E\\x47\\x0D\\x0A\\x1A\\x0A\": \"png\",\n", + " b\"\\x47\\x49\\x46\\x38\": \"gif\",\n", + " b\"\\x52\\x49\\x46\\x46\": \"webp\",\n", + " }\n", + " try:\n", + " header = base64.b64decode(b64data)[:8] # Decode and get the first 8 bytes\n", + " for sig, format in image_signatures.items():\n", + " if header.startswith(sig):\n", + " return True\n", + " return False\n", + " except Exception:\n", + " return False\n", + "\n", + "\n", + "def split_image_text_types(docs):\n", + " \"\"\"\n", + " Split base64-encoded images and texts\n", + " \"\"\"\n", + " b64_images = []\n", + " texts = []\n", + " for doc in docs:\n", + " # Check if the document is of type Document and extract page_content if so\n", + " if isinstance(doc, Document):\n", + " doc = doc.page_content\n", + " if looks_like_base64(doc) and is_image_data(doc):\n", + " b64_images.append(doc)\n", + " else:\n", + " texts.append(doc)\n", + " return {\"images\": b64_images, \"texts\": texts}\n", + "\n", + "\n", + "def img_prompt_func(data_dict):\n", + " \"\"\"\n", + " Join the context into a single string\n", + " \"\"\"\n", + " formatted_texts = \"\\n\".join(data_dict[\"context\"][\"texts\"])\n", + " messages = [\n", + " {\n", + " \"type\": \"text\",\n", + " \"text\": (\n", + " \"You are financial analyst tasking with providing investment advice.\\n\"\n", + " \"You will be given a mix of text, tables, and image(s) usually of charts or graphs.\\n\"\n", + " \"Use this information to provide investment advice related to the user's question. \\n\"\n", + " f\"User-provided question: {data_dict['question']}\\n\\n\"\n", + " \"Text and / or tables:\\n\"\n", + " f\"{formatted_texts}\"\n", + " ),\n", + " }\n", + " ]\n", + "\n", + " # Adding image(s) to the messages if present\n", + " if data_dict[\"context\"][\"images\"]:\n", + " for image in data_dict[\"context\"][\"images\"]:\n", + " messages.append(\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\"url\": f\"data:image/jpeg;base64,{image}\"},\n", + " }\n", + " )\n", + " return [HumanMessage(content=messages)]\n", + "\n", + "\n", + "# Create RAG chain\n", + "chain_multimodal_rag = (\n", + " {\n", + " \"context\": retriever_multi_vector_img | RunnableLambda(split_image_text_types),\n", + " \"question\": RunnablePassthrough(),\n", + " }\n", + " | RunnableLambda(img_prompt_func)\n", + " | ChatVertexAI(\n", + " temperature=0,\n", + " model_name=MODEL_NAME,\n", + " max_output_tokens=TOKEN_LIMIT,\n", + " ) # Multi-modal LLM\n", + " | StrOutputParser()\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2925d397fdbb" + }, + "source": [ + "## Process user query" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "b3b445f934a8" + }, + "outputs": [], + "source": [ + "query = \"\"\"\n", + " - What are the critical difference between various graphs for Class A Share?\n", + " - Which index best matches Class A share performance closely where Google is not already a part? Explain the reasoning.\n", + " - Identify key chart patterns for Google Class A shares.\n", + " - What is cost of revenues, operating expenses and net income for 2020. Do mention the percentage change\n", + " - What was the effect of Covid in the 2020 financial year?\n", + " - What are the total revenues for APAC and USA for 2021?\n", + " - What is deferred income taxes?\n", + " - How do you compute net income per share?\n", + " - What drove percentage change in the consolidated revenue and cost of revenue for the year 2021 and was there any effect of Covid?\n", + " - What is the cause of 41% increase in revenue from 2020 to 2021 and how much is dollar change?\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6172a22b1203" + }, + "source": [ + "### Get Retrieved documents" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "90a34d3712e0" + }, + "outputs": [], + "source": [ + "# List of source documents\n", + "docs = retriever_multi_vector_img.get_relevant_documents(query, limit=10)\n", + "\n", + "source_docs = split_image_text_types(docs)\n", + "\n", + "print(source_docs[\"texts\"])\n", + "\n", + "for i in source_docs[\"images\"]:\n", + " display(Image(base64.b64decode(i)))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bd784ce7f205" + }, + "source": [ + "### Get generative response" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4c5f936f89da" + }, + "outputs": [], + "source": [ + "result = chain_multimodal_rag.invoke(query)\n", + "\n", + "Markdown(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KwNrHCqbi3xi" + }, + "source": [ + "## Conclusions" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "05jynhZnkgxn" + }, + "source": [ + "Congratulations on making it through this multimodal RAG notebook!\n", + "\n", + "While multimodal RAG can be quite powerful, note that it can face some limitations:\n", + "\n", + "* **Data dependency:** Needs high-accuracy data from the text and visuals.\n", + "* **Computationally demanding:** Generating embeddings from multimodal data is resource-intensive.\n", + "* **Domain specific:** Models trained on general data may not shine in specialized fields like medicine.\n", + "* **Black box:** Understanding how these models work can be tricky, hindering trust and adoption.\n", + "\n", + "\n", + "Despite these challenges, multimodal RAG represents a significant step towards search and retrieval systems that can handle diverse, multimodal data." + ] + } + ], + "metadata": { + "colab": { + "name": "multimodal_rag_langchain.ipynb", + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 }