diff --git a/CHANGELOG.md b/CHANGELOG.md index 091b147a8de5..b253c0d06aff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `loader.RagQueryLoader` with Remote Backend Example ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597)) - Added `data.LargeGraphIndexer` ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597)) - Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467)) diff --git a/docs/source/_figures/flowchart.svg b/docs/source/_figures/flowchart.svg new file mode 100644 index 000000000000..188d37b14f41 --- /dev/null +++ b/docs/source/_figures/flowchart.svg @@ -0,0 +1 @@ + diff --git a/docs/source/_figures/multihop_example.svg b/docs/source/_figures/multihop_example.svg new file mode 100644 index 000000000000..4925dcb9713d --- /dev/null +++ b/docs/source/_figures/multihop_example.svg @@ -0,0 +1 @@ + diff --git a/docs/source/_figures/remote_backend.svg b/docs/source/_figures/remote_backend.svg new file mode 100644 index 000000000000..c5791f0a95de --- /dev/null +++ b/docs/source/_figures/remote_backend.svg @@ -0,0 +1 @@ + diff --git a/docs/source/advanced/rag.rst b/docs/source/advanced/rag.rst new file mode 100644 index 000000000000..5473a44e5af3 --- /dev/null +++ b/docs/source/advanced/rag.rst @@ -0,0 +1,566 @@ +Working with LLM RAG in Pytorch Geometric +========================================= + +This series aims to provide a starting point and for +multi-step LLM Retrieval Augmented Generation +(RAG) using Graph Neural Networks. + +Motivation +---------- + +As Large Language Models (LLMs) quickly grow to dominate industry, they +are increasingly being deployed at scale in use cases that require very +specific contextual expertise. LLMs often struggle with these cases out +of the box, as they will hallucinate answers that are not included in +their training data. At the same time, many business already have large +graph databases full of important context that can provide important +domain-specific context to reduce hallucination and improve answer +fidelity for LLMs. Graph Neural Networks (GNNs) provide a means for +efficiently encoding this contextual information into the model, which +can help LLMs to better understand and generate answers. Hence, theres +an open research question as to how to effectively use GNN encodings +efficiently for this purpose, that the tooling provided here can help +investigate. + +Architecture +------------ + +To model the use-case of RAG from a large knowledge graph of millions of +nodes, we present the following architecture: + + + + + +.. figure:: ../_figures/flowchart.svg + :align: center + :width: 100% + + + +Graph RAG as shown in the diagram above follows the following order of +operations: + +0. To start, not pictured here, there must exist a large knowledge graph + that exists as a source of truth. The nodes and edges of this + knowledge graph + +During inference time, RAG implementations that follow this architecture +are composed of the following steps: + +1. Tokenize and encode the query using the LLM Encoder +2. Retrieve a subgraph of the larger knowledge graph (KG) relevant to + the query and encode it using a GNN +3. Jointly embed the GNN embedding with the LLM embedding +4. Utilize LLM Decoder to decode joint embedding and generate a response + + + + +Encoding a Large Knowledge Graph +================================ + +To start, a Large Knowledge Graph needs to be created from triplets or +multiple subgraphs in a dataset. + +Example 1: Building from Already Existing Datasets +-------------------------------------------------- + +In most RAG scenarios, the subset of the information corpus that gets +retrieved is crucial for whether the appropriate response to the LLM. +The same is true for GNN based RAG. For example, consider the +WebQSPDataset. + +.. code:: python + + from torch_geometric.datasets import WebQSPDataset + + num_questions = 100 + ds = WebQSPDataset('small_sample', limit=num_questions) + + +WebQSP is a dataset that is based off of a subset of the Freebase +Knowledge Graph, which is an open-source knowledge graph formerly +maintained by Google. For each question-answer pair in the dataset, a +subgraph was chosen based on a Semantic SPARQL search on the larger +knowledge graph, to provide relevent context on finding the answer. So +each entry in the dataset consists of: +- A question to be answered +- The answer +- A knowledge graph subgraph of Freebase that has the context +needed to answer the question. + +.. code:: python + + ds.raw_dataset + + >>> Dataset({ + features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'], + num_rows: 100 + }) + + + +.. code:: python + + ds.raw_dataset[0] + + + >>> {'id': 'WebQTrn-0', + 'question': 'what is the name of justin bieber brother', + 'answer': ['Jaxon Bieber'], + 'q_entity': ['Justin Bieber'], + 'a_entity': ['Jaxon Bieber'], + 'graph': [['P!nk', 'freebase.valuenotation.is_reviewed', 'Gender'], + ['1Club.FM: Power', 'broadcast.content.artist', 'P!nk'], + ...], + 'choices': []} + + + +Although this dataset can be trained on as-is, a couple problems emerge +from doing so: +1. A retrieval algorithm needs to be implemented and +executed during inference time, that might not appropriately correspond +to the algorithm that was used to generate the dataset subgraphs. +1. The dataset as is not stored computationally efficiently, as there will +exist many duplicate nodes and edges that are shared between the +questions. + +As a result, it makes sense in this scenario to be able to encode all +the entries into a large knowledge graph, so that duplicate nodes and +edges can be avoided, and so that alternative retrieval algorithms can +be tried. We can do this with the LargeGraphIndexer class: + +.. code:: python + + from torch_geometric.data import LargeGraphIndexer, Data, get_features_for_triplets_groups + from torch_geometric.nn.nlp import SentenceTransformer + import time + import torch + import tqdm + from itertools import chain + import networkx as nx + +.. code:: python + + raw_dataset_graphs = [[tuple(trip) for trip in graph] for graph in ds.raw_dataset['graph']] + print(raw_dataset_graphs[0][:10]) + + >>> [('P!nk', 'freebase.valuenotation.is_reviewed', 'Gender'), ('1Club.FM: Power', 'broadcast.content.artist', 'P!nk'), ...] + + +To show the benefits of this indexer in action, we will use the +following model to encode this sample of graphs using LargeGraphIndexer, +along with naively. + +.. code:: python + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = SentenceTransformer(model_name='sentence-transformers/all-roberta-large-v1').to(device) + + +First, we compare the clock times of encoding using both methods. + +.. code:: python + + # Indexing question-by-question + dataset_graphs_embedded = [] + start = time.time() + for graph in tqdm.tqdm(raw_dataset_graphs): + nodes_map = dict() + edges_map = dict() + edge_idx_base = [] + + for src, edge, dst in graph: + # Collect nodes + if src not in nodes_map: + nodes_map[src] = len(nodes_map) + if dst not in nodes_map: + nodes_map[dst] = len(nodes_map) + + # Collect edge types + if edge not in edges_map: + edges_map[edge] = len(edges_map) + + # Record edge + edge_idx_base.append((nodes_map[src], edges_map[edge], nodes_map[dst])) + + # Encode nodes and edges + sorted_nodes = list(sorted(nodes_map.keys(), key=lambda x: nodes_map[x])) + sorted_edges = list(sorted(edges_map.keys(), key=lambda x: edges_map[x])) + + x = model.encode(sorted_nodes, batch_size=256) + edge_attrs_map = model.encode(sorted_edges, batch_size=256) + + edge_attrs = [] + edge_idx = [] + for trip in edge_idx_base: + edge_attrs.append(edge_attrs_map[trip[1]]) + edge_idx.append([trip[0], trip[2]]) + + dataset_graphs_embedded.append(Data(x=x, edge_index=torch.tensor(edge_idx).T, edge_attr=torch.stack(edge_attrs, dim=0))) + + + print(time.time()-start) + + >>> 121.68579435348511 + + + +.. code:: python + + # Using LargeGraphIndexer to make one large knowledge graph + from torch_geometric.data.large_graph_indexer import EDGE_RELATION + + start = time.time() + all_triplets_together = chain.from_iterable(raw_dataset_graphs) + # Index as one large graph + print('Indexing...') + indexer = LargeGraphIndexer.from_triplets(all_triplets_together) + + # first the nodes + unique_nodes = indexer.get_unique_node_features() + node_encs = model.encode(unique_nodes, batch_size=256) + indexer.add_node_feature(new_feature_name='x', new_feature_vals=node_encs) + + # then the edges + unique_edges = indexer.get_unique_edge_features(feature_name=EDGE_RELATION) + edge_attr = model.encode(unique_edges, batch_size=256) + indexer.add_edge_feature(new_feature_name="edge_attr", new_feature_vals=edge_attr, map_from_feature=EDGE_RELATION) + + ckpt_time = time.time() + whole_knowledge_graph = indexer.to_data(node_feature_name='x', edge_feature_name='edge_attr') + whole_graph_done = time.time() + print(f"Time to create whole knowledge_graph: {whole_graph_done-start}") + + # Compute this to make sure we're comparing like to like on final time printout + whole_graph_diff = whole_graph_done-ckpt_time + + # retrieve subgraphs + print('Retrieving Subgraphs...') + dataset_graphs_embedded_largegraphindexer = [graph for graph in tqdm.tqdm(get_features_for_triplets_groups(indexer=indexer, triplet_groups=raw_dataset_graphs), total=num_questions)] + print(time.time()-start-whole_graph_diff) + + >>> Indexing... + >>> Time to create whole knowledge_graph: 114.01080107688904 + >>> Retrieving Subgraphs... + >>> 114.66037964820862 + + +The large graph indexer allows us to compute the entire knowledge graph +from a series of samples, so that new retrieval methods can also be +tested on the entire graph. We will see this attempted in practice later +on. + +It’s worth noting that, although the times are relatively similar right +now, the speedup with largegraphindexer will be much higher as the size +of the knowledge graph grows. This is due to the speedup being a factor +of the number of unique nodes and edges in the graph. + + +We expect the two results to be functionally identical, with the +differences being due to floating point jitter. + +.. code:: python + + def results_are_close_enough(ground_truth: Data, new_method: Data, thresh=.8): + def _sorted_tensors_are_close(tensor1, tensor2): + return torch.all(torch.isclose(tensor1.sort(dim=0)[0], tensor2.sort(dim=0)[0]).float().mean(axis=1) > thresh) + def _graphs_are_same(tensor1, tensor2): + return nx.weisfeiler_lehman_graph_hash(nx.Graph(tensor1.T)) == nx.weisfeiler_lehman_graph_hash(nx.Graph(tensor2.T)) + return _sorted_tensors_are_close(ground_truth.x, new_method.x) \ + and _sorted_tensors_are_close(ground_truth.edge_attr, new_method.edge_attr) \ + and _graphs_are_same(ground_truth.edge_index, new_method.edge_index) + + + all_results_match = True + for old_graph, new_graph in tqdm.tqdm(zip(dataset_graphs_embedded, dataset_graphs_embedded_largegraphindexer), total=num_questions): + all_results_match &= results_are_close_enough(old_graph, new_graph) + all_results_match + + >>> True + + + +When scaled up to the entire dataset, we see a 2x speedup with indexing +this way on the WebQSP Dataset. + +Example 2: Building a new Dataset from Questions and an already-existing Knowledge Graph +---------------------------------------------------------------------------------------- + +Motivation +~~~~~~~~~~ + +One potential application of knowledge graph structural encodings is +capturing the relationships between different entities that are multiple +hops apart. This can be challenging for an LLM to recognize from +prepended graph information. Here’s a motivating example (credit to +@Rishi Puri): + + +.. figure:: ../_figures/multihop_example.svg + :align: center + :width: 100% + + + +In this example, the question can only be answered by reasoning about +the relationships between the entities in the knowledge graph. + +Building a Multi-Hop QA Dataset +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To see an example of encoding a large knowledge graph starting from an +existing set of triplets, check out the multi-hop example in +`examples/llm_plus_gnn/multihop_rag`. + +Question: How do we extract a contextual subgraph for a given query? +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The chosen retrieval algorithm is a critical component in the pipeline +for affecting RAG performance. In the next section, we will +demonstrate a naive method of retrieval for a large knowledge graph. + + +Retrieval Algorithms and Scaling Retrieval +========================================== + +Motivation +---------- + +When building a RAG Pipeline for inference, the retrieval component is +important for the following reasons: +1. A given algorithm for retrieving subgraph context can have a +marked effect on the hallucination rate of the responses in the model +2. A given retrieval algorithm needs to be able to scale to larger +graphs of millions of nodes and edges in order to be practical for production. + +In this section, we will explore how to construct a RAG retrieval +algorithm from a given subgraph, and conduct some experiments to +evaluate its runtime performance. + +We want to do so in-line with Pytorch Geometric’s in-house framework for +remote backends: + + +.. figure:: ../_figures/remote_2.png + :align: center + :width: 100% + + + +As seen here, the GraphStore is used to store the neighbor relations +between the nodes of the graph, whereas the FeatureStore is used to +store the node and edge features in the graph. + +Let’s start by loading in a knowledge graph dataset for the sake of our +experiment: + +.. code:: python + + from torch_geometric.data import LargeGraphIndexer + from torch_geometric.datasets import WebQSPDataset + from itertools import chain + + ds = WebQSPDataset(root='demo', limit=10) + +Let’s set up our set of questions and graph triplets: + +.. code:: python + + questions = ds.raw_dataset['question'] + questions + + >>> ['what is the name of justin bieber brother', + 'what character did natalie portman play in star wars', + 'what country is the grand bahama island in', + 'what kind of money to take to bahamas', + 'what character did john noble play in lord of the rings', + 'who does joakim noah play for', + 'where are the nfl redskins from', + 'where did saki live', + 'who did draco malloy end up marrying', + 'which countries border the us'] + + + ds.raw_dataset[:10]['graph'][0][:10] + + + >>> [['P!nk', 'freebase.valuenotation.is_reviewed', 'Gender'], + ['1Club.FM: Power', 'broadcast.content.artist', 'P!nk'], + ['Somebody to Love', 'music.recording.contributions', 'm.0rqp4h0'], + ['Rudolph Valentino', 'freebase.valuenotation.is_reviewed', 'Place of birth'], + ['Ice Cube', 'broadcast.artist.content', '.977 The Hits Channel'], + ['Colbie Caillat', 'broadcast.artist.content', 'Hot Wired Radio'], + ['Stephen Melton', 'people.person.nationality', 'United States of America'], + ['Record producer', + 'music.performance_role.regular_performances', + 'm.012m1vf1'], + ['Justin Bieber', 'award.award_winner.awards_won', 'm.0yrkc0l'], + ['1.FM Top 40', 'broadcast.content.artist', 'Geri Halliwell']] + + + all_triplets = chain.from_iterable((row['graph'] for row in ds.raw_dataset)) + +With these questions and triplets, we want to: +1. Consolidate all the relations in these triplets into a Knowledge Graph +2. Create a FeatureStore that encodes all the nodes and edges in the knowledge graph +3. Create a GraphStore that encodes all the edge indices in the knowledge graph + +In order to create a remote backend, we need to define a FeatureStore +and GraphStore locally, as well as a method for initializing its state +from triplets. The code methods used in this tutorial can be found in +`examples/llm_plus_gnn`. + +.. code:: python + + from torch_geometric.datasets.web_qsp_dataset import preprocess_triplet + from rag_construction_utils import create_remote_backend_from_triplets, RemoteGraphBackendLoader + + # We define this GraphStore to sample the neighbors of a node locally. + # Ideally for a real remote backend, this interface would be replaced with an API to a Graph DB, such as Neo4j. + from rag_graph_store import NeighborSamplingRAGGraphStore + + # We define this FeatureStore to encode the nodes and edges locally, and perform appoximate KNN when indexing. + # Ideally for a real remote backend, this interface would be replaced with an API to a vector DB, such as Pinecone. + from rag_feature_store import SentenceTransformerFeatureStore + +.. code:: python + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = SentenceTransformer(model_name="sentence-transformers/all-roberta-large-v1").to(device) + + backend_loader: RemoteGraphBackendLoader = create_remote_backend_from_triplets( + triplets=all_triplets, # All the triplets to insert into the backend + node_embedding_model=model, # Embedding model to process triplets with + node_method_to_call="encode", # This method will encode the nodes/edges with 'model.encode' in this case. + path="backend", # Save path + pre_transform=preprocess_triplet, # Preprocessing function to apply to triplets before invoking embedding model. + node_method_kwargs={"batch_size": 256}, # Keyword arguments to pass to the node_method_to_call. + graph_db=NeighborSamplingRAGGraphStore, # Graph Store to use + feature_db=SentenceTransformerFeatureStore # Feature Store to use + ) + # This loader saves a copy of the processed data locally to be transformed into a graphstore and featurestore when load() is called. + feature_store, graph_store = backend_loader.load() + +Now that we have initialized our remote backends, we can now retrieve +from them using a Loader to query the backends, as shown in this +diagram: + + +.. figure:: ../_figures/remote_3.png + :align: center + :width: 100% + + + +.. code:: python + + from torch_geometric.loader import RAGQueryLoader + + query_loader = RAGQueryLoader( + data=(feature_store, graph_store), # Remote Rag Graph Store and Feature Store + # Arguments to pass into the seed node/edge retrieval methods for the FeatureStore. + # In this case, it's k for the KNN on the nodes and edges. + seed_nodes_kwargs={"k_nodes": 10}, seed_edges_kwargs={"k_edges": 10}, + # Arguments to pass into the GraphStore's Neighbor sampling method. + # In this case, the GraphStore implements a NeighborLoader, so it takes the same arguments. + sampler_kwargs={"num_neighbors": [40]*3}, + # Arguments to pass into the FeatureStore's feature loading method. + loader_kwargs={}, + # An optional local transform that can be applied on the returned subgraph. + local_filter=None, + ) + +To make better sense of this loader’s arguments, let’s take a closer +look at the retrieval process for a remote backend: + + +.. figure:: ../_figures/remote_backend.svg + :align: center + :width: 100% + + + +As we see here, there are 3 important steps to any remote backend +procedure for graphs: +1. Retrieve the seed nodes and edges to begin our retrieval process from. +2. Traverse the graph neighborhood of the seed nodes/edges to gather local context. +3. Fetch the features associated with the subgraphs obtained from the traversal. + +We can see that our Query Loader construction allows us to specify +unique hyperparameters for each unique step in this retrieval. + +Now we can submit our queries to the remote backend to retrieve our +subgraphs: + +.. code:: python + + sub_graphs = [] + for q in tqdm.tqdm(questions): + sub_graphs.append(query_loader.query(q)) + + + sub_graphs[0] + + >>> Data(x=[2251, 1024], edge_index=[2, 7806], edge_attr=[7806, 1024], node_idx=[2251], edge_idx=[7806]) + + + +These subgraphs are now retrieved using a different retrieval method +when compared to the original WebQSP dataset. Can we compare the +properties of this method to the original WebQSPDataset’s retrieval +method? Let’s compare some basics properties of the subgraphs: + +.. code:: python + + def _eidx_helper(subg: Data, ground_truth: Data): + subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx + if isinstance(subg_eidx, torch.Tensor): + subg_eidx = subg_eidx.tolist() + if isinstance(gt_eidx, torch.Tensor): + gt_eidx = gt_eidx.tolist() + subg_e = set(subg_eidx) + gt_e = set(gt_eidx) + return subg_e, gt_e + def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + total_e = set(range(num_edges)) + tp = len(subg_e & gt_e) + tn = len(total_e-(subg_e | gt_e)) + return (tp+tn)/num_edges + def check_retrieval_precision(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(subg_e) + def check_retrieval_recall(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(gt_e) + + + ground_truth_graphs = get_features_for_triplets_groups(ds.indexer, (d['graph'] for d in ds.raw_dataset), pre_transform=preprocess_triplet) + num_edges = len(ds.indexer._edges) + + + for subg, ground_truth in tqdm.tqdm(zip((query_loader.query(q) for q in questions), ground_truth_graphs)): + print(f"Size: {len(subg.x)}, Ground Truth Size: {len(ground_truth.x)}, Accuracy: {check_retrieval_accuracy(subg, ground_truth, num_edges)}, Precision: {check_retrieval_precision(subg, ground_truth)}, Recall: {check_retrieval_recall(subg, ground_truth)}") + + >>> Size: 2193, Ground Truth Size: 1709, Accuracy: 0.6636780705203827, Precision: 0.22923807012918535, Recall: 0.1994037381034285 + >>> Size: 2682, Ground Truth Size: 1251, Accuracy: 0.7158736400576746, Precision: 0.10843513670738801, Recall: 0.22692963233503774 + >>> Size: 2087, Ground Truth Size: 1285, Accuracy: 0.7979813868134749, Precision: 0.0547879177377892, Recall: 0.15757855822550831 + >>> Size: 2975, Ground Truth Size: 1988, Accuracy: 0.6956088609254162, Precision: 0.14820555621795636, Recall: 0.21768826619964973 + >>> Size: 2594, Ground Truth Size: 633, Accuracy: 0.78849128326124, Precision: 0.04202616198163095, Recall: 0.2032301480484522 + >>> Size: 2462, Ground Truth Size: 1044, Accuracy: 0.7703499803381832, Precision: 0.07646643109540636, Recall: 0.19551861221539574 + >>> Size: 2011, Ground Truth Size: 1382, Accuracy: 0.7871804954777821, Precision: 0.10117783355860205, Recall: 0.13142713819914723 + >>> Size: 2011, Ground Truth Size: 1052, Accuracy: 0.802831301612269, Precision: 0.06452691407556001, Recall: 0.16702726092600606 + >>> Size: 2892, Ground Truth Size: 1012, Accuracy: 0.7276182985974571, Precision: 0.10108615156751419, Recall: 0.20860927152317882 + >>> Size: 1817, Ground Truth Size: 1978, Accuracy: 0.7530475815965395, Precision: 0.1677807486631016, Recall: 0.11696178937558248 + + + +Note that, since we’re only comparing the results of 10 graphs here, +this retrieval algorithm is not taking into account the full corpus of +nodes in the dataset. If you want to see a full example, look at +``rag_generate.py``, or ``rag_generate_multihop.py`` These examples +generate datasets for the entirety of the WebQSP dataset, or the +WikiData Multihop datasets that are discussed in Section 0. diff --git a/docs/source/index.rst b/docs/source/index.rst index 1c67eeddeec6..02315461c6ec 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -44,6 +44,7 @@ In addition, it consists of easy-to-use mini-batch loaders for operating on many advanced/remote advanced/graphgym advanced/cpu_affinity + advanced/rag .. toctree:: :maxdepth: 1 diff --git a/examples/llm/README.md b/examples/llm/README.md index f1f01428d991..51043ae84a5c 100644 --- a/examples/llm/README.md +++ b/examples/llm/README.md @@ -1,5 +1,7 @@ # Examples for Co-training LLMs and GNNs -| Example | Description | -| ------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | +| Example | Description | +| -------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | +| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. | +| [`multihop_rag/`](./multihop_rag/) | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA | diff --git a/examples/llm/g_retriever_utils/README.md b/examples/llm/g_retriever_utils/README.md new file mode 100644 index 000000000000..0500b4a7e5ce --- /dev/null +++ b/examples/llm/g_retriever_utils/README.md @@ -0,0 +1,8 @@ +# Examples for LLM and GNN co-training + +| Example | Description | +| ---------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [`rag_feature_store.py`](./rag_feature_store.py) | A Proof of Concept Implementation of a RAG enabled FeatureStore that can serve as a starting point for implementing a custom RAG Remote Backend | +| [`rag_graph_store.py`](./rag_graph_store.py) | A Proof of Concept Implementation of a RAG enabled GraphStore that can serve as a starting point for implementing a custom RAG Remote Backend | +| [`rag_backend_utils.py`](./rag_backend_utils.py) | Utility functions used for loading a series of Knowledge Graph Triplets into the Remote Backend defined by a FeatureStore and GraphStore | +| [`rag_generate.py`](./rag_generate.py) | Script for generating a unique set of subgraphs from the WebQSP dataset using a custom defined retrieval algorithm (defaults to the FeatureStore and GraphStore provided) | diff --git a/examples/llm/g_retriever_utils/rag_backend_utils.py b/examples/llm/g_retriever_utils/rag_backend_utils.py new file mode 100644 index 000000000000..0f1c0e1b87ec --- /dev/null +++ b/examples/llm/g_retriever_utils/rag_backend_utils.py @@ -0,0 +1,224 @@ +from dataclasses import dataclass +from enum import Enum, auto +from typing import ( + Any, + Callable, + Dict, + Iterable, + Optional, + Protocol, + Tuple, + Type, + runtime_checkable, +) + +import torch +from torch import Tensor +from torch.nn import Module + +from torch_geometric.data import ( + FeatureStore, + GraphStore, + LargeGraphIndexer, + TripletLike, +) +from torch_geometric.data.large_graph_indexer import EDGE_RELATION +from torch_geometric.distributed import ( + LocalFeatureStore, + LocalGraphStore, + Partitioner, +) +from torch_geometric.typing import EdgeType, NodeType + +RemoteGraphBackend = Tuple[FeatureStore, GraphStore] + +# TODO: Make everything compatible with Hetero graphs aswell + + +# Adapted from LocalGraphStore +@runtime_checkable +class ConvertableGraphStore(Protocol): + @classmethod + def from_data( + cls, + edge_id: Tensor, + edge_index: Tensor, + num_nodes: int, + is_sorted: bool = False, + ) -> GraphStore: + ... + + @classmethod + def from_hetero_data( + cls, + edge_id_dict: Dict[EdgeType, Tensor], + edge_index_dict: Dict[EdgeType, Tensor], + num_nodes_dict: Dict[NodeType, int], + is_sorted: bool = False, + ) -> GraphStore: + ... + + @classmethod + def from_partition(cls, root: str, pid: int) -> GraphStore: + ... + + +# Adapted from LocalFeatureStore +@runtime_checkable +class ConvertableFeatureStore(Protocol): + @classmethod + def from_data( + cls, + node_id: Tensor, + x: Optional[Tensor] = None, + y: Optional[Tensor] = None, + edge_id: Optional[Tensor] = None, + edge_attr: Optional[Tensor] = None, + ) -> FeatureStore: + ... + + @classmethod + def from_hetero_data( + cls, + node_id_dict: Dict[NodeType, Tensor], + x_dict: Optional[Dict[NodeType, Tensor]] = None, + y_dict: Optional[Dict[NodeType, Tensor]] = None, + edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None, + edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None, + ) -> FeatureStore: + ... + + @classmethod + def from_partition(cls, root: str, pid: int) -> FeatureStore: + ... + + +class RemoteDataType(Enum): + DATA = auto() + PARTITION = auto() + + +@dataclass +class RemoteGraphBackendLoader: + """Utility class to load triplets into a RAG Backend.""" + path: str + datatype: RemoteDataType + graph_store_type: Type[ConvertableGraphStore] + feature_store_type: Type[ConvertableFeatureStore] + + def load(self, pid: Optional[int] = None) -> RemoteGraphBackend: + if self.datatype == RemoteDataType.DATA: + data_obj = torch.load(self.path) + graph_store = self.graph_store_type.from_data( + edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index, + num_nodes=data_obj.num_nodes) + feature_store = self.feature_store_type.from_data( + node_id=data_obj['node_id'], x=data_obj.x, + edge_id=data_obj['edge_id'], edge_attr=data_obj.edge_attr) + elif self.datatype == RemoteDataType.PARTITION: + if pid is None: + assert pid is not None, \ + "Partition ID must be defined for loading from a " \ + + "partitioned store." + graph_store = self.graph_store_type.from_partition(self.path, pid) + feature_store = self.feature_store_type.from_partition( + self.path, pid) + else: + raise NotImplementedError + return (feature_store, graph_store) + + +# TODO: make profilable +def create_remote_backend_from_triplets( + triplets: Iterable[TripletLike], node_embedding_model: Module, + edge_embedding_model: Module | None = None, + graph_db: Type[ConvertableGraphStore] = LocalGraphStore, + feature_db: Type[ConvertableFeatureStore] = LocalFeatureStore, + node_method_to_call: str = "forward", + edge_method_to_call: str | None = None, + pre_transform: Callable[[TripletLike], TripletLike] | None = None, + path: str = '', n_parts: int = 1, + node_method_kwargs: Optional[Dict[str, Any]] = None, + edge_method_kwargs: Optional[Dict[str, Any]] = None +) -> RemoteGraphBackendLoader: + """Utility function that can be used to create a RAG Backend from triplets. + + Args: + triplets (Iterable[TripletLike]): Triplets to load into the RAG + Backend. + node_embedding_model (Module): Model to embed nodes into a feature + space. + edge_embedding_model (Module | None, optional): Model to embed edges + into a feature space. Defaults to the node model. + graph_db (Type[ConvertableGraphStore], optional): GraphStore class to + use. Defaults to LocalGraphStore. + feature_db (Type[ConvertableFeatureStore], optional): FeatureStore + class to use. Defaults to LocalFeatureStore. + node_method_to_call (str, optional): method to call for embeddings on + the node model. Defaults to "forward". + edge_method_to_call (str | None, optional): method to call for + embeddings on the edge model. Defaults to the node method. + pre_transform (Callable[[TripletLike], TripletLike] | None, optional): + optional preprocessing function for triplets. Defaults to None. + path (str, optional): path to save resulting stores. Defaults to ''. + n_parts (int, optional): Number of partitons to store in. + Defaults to 1. + node_method_kwargs (Optional[Dict[str, Any]], optional): args to pass + into node encoding method. Defaults to None. + edge_method_kwargs (Optional[Dict[str, Any]], optional): args to pass + into edge encoding method. Defaults to None. + + Returns: + RemoteGraphBackendLoader: Loader to load RAG backend from disk or + memory. + """ + # Will return attribute errors for missing attributes + if not issubclass(graph_db, ConvertableGraphStore): + getattr(graph_db, "from_data") + getattr(graph_db, "from_hetero_data") + getattr(graph_db, "from_partition") + elif not issubclass(feature_db, ConvertableFeatureStore): + getattr(feature_db, "from_data") + getattr(feature_db, "from_hetero_data") + getattr(feature_db, "from_partition") + + # Resolve callable methods + node_method_kwargs = node_method_kwargs \ + if node_method_kwargs is not None else dict() + + edge_embedding_model = edge_embedding_model \ + if edge_embedding_model is not None else node_embedding_model + edge_method_to_call = edge_method_to_call \ + if edge_method_to_call is not None else node_method_to_call + edge_method_kwargs = edge_method_kwargs \ + if edge_method_kwargs is not None else node_method_kwargs + + # These will return AttributeErrors if they don't exist + node_model = getattr(node_embedding_model, node_method_to_call) + edge_model = getattr(edge_embedding_model, edge_method_to_call) + + indexer = LargeGraphIndexer.from_triplets(triplets, + pre_transform=pre_transform) + + node_feats = node_model(indexer.get_node_features(), **node_method_kwargs) + indexer.add_node_feature('x', node_feats) + + edge_feats = edge_model( + indexer.get_unique_edge_features(feature_name=EDGE_RELATION), + **edge_method_kwargs) + indexer.add_edge_feature(new_feature_name="edge_attr", + new_feature_vals=edge_feats, + map_from_feature=EDGE_RELATION) + + data = indexer.to_data(node_feature_name='x', + edge_feature_name='edge_attr') + + if n_parts == 1: + torch.save(data, path) + return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db, + feature_db) + else: + partitioner = Partitioner(data=data, num_parts=n_parts, root=path) + partitioner.generate_partition() + return RemoteGraphBackendLoader(path, RemoteDataType.PARTITION, + graph_db, feature_db) diff --git a/examples/llm/g_retriever_utils/rag_feature_store.py b/examples/llm/g_retriever_utils/rag_feature_store.py new file mode 100644 index 000000000000..e01e9e59bb88 --- /dev/null +++ b/examples/llm/g_retriever_utils/rag_feature_store.py @@ -0,0 +1,189 @@ +import gc +from collections.abc import Iterable, Iterator +from typing import Any, Dict, Optional, Type, Union + +import torch +from torch import Tensor +from torch.nn import Module +from torchmetrics.functional import pairwise_cosine_similarity + +from torch_geometric.data import Data, HeteroData +from torch_geometric.distributed import LocalFeatureStore +from torch_geometric.nn.nlp import SentenceTransformer +from torch_geometric.nn.pool import ApproxMIPSKNNIndex +from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput +from torch_geometric.typing import InputEdges, InputNodes + + +# NOTE: Only compatible with Homogeneous graphs for now +class KNNRAGFeatureStore(LocalFeatureStore): + def __init__(self, enc_model: Type[Module], + model_kwargs: Optional[Dict[str, + Any]] = None, *args, **kwargs): + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu") + self.enc_model = enc_model(*args, **kwargs).to(self.device) + self.enc_model.eval() + self.model_kwargs = \ + model_kwargs if model_kwargs is not None else dict() + super().__init__() + + @property + def x(self) -> Tensor: + return self.get_tensor(group_name=None, attr_name='x') + + @property + def edge_attr(self) -> Tensor: + return self.get_tensor(group_name=(None, None), attr_name='edge_attr') + + def retrieve_seed_nodes(self, query: Any, k_nodes: int = 5) -> InputNodes: + result = next(self._retrieve_seed_nodes_batch([query], k_nodes)) + gc.collect() + torch.cuda.empty_cache() + return result + + def _retrieve_seed_nodes_batch(self, query: Iterable[Any], + k_nodes: int) -> Iterator[InputNodes]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + query_enc = self.enc_model.encode(query, + **self.model_kwargs).to(self.device) + prizes = pairwise_cosine_similarity(query_enc, self.x.to(self.device)) + topk = min(k_nodes, len(self.x)) + for q in prizes: + _, indices = torch.topk(q, topk, largest=True) + yield indices + + def retrieve_seed_edges(self, query: Any, k_edges: int = 3) -> InputEdges: + result = next(self._retrieve_seed_edges_batch([query], k_edges)) + gc.collect() + torch.cuda.empty_cache() + return result + + def _retrieve_seed_edges_batch(self, query: Iterable[Any], + k_edges: int) -> Iterator[InputEdges]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + query_enc = self.enc_model.encode(query, + **self.model_kwargs).to(self.device) + + prizes = pairwise_cosine_similarity(query_enc, + self.edge_attr.to(self.device)) + topk = min(k_edges, len(self.edge_attr)) + for q in prizes: + _, indices = torch.topk(q, topk, largest=True) + yield indices + + def load_subgraph( + self, sample: Union[SamplerOutput, HeteroSamplerOutput] + ) -> Union[Data, HeteroData]: + + if isinstance(sample, HeteroSamplerOutput): + raise NotImplementedError + + # NOTE: torch_geometric.loader.utils.filter_custom_store can be used + # here if it supported edge features + node_id = sample.node + edge_id = sample.edge + edge_index = torch.stack((sample.row, sample.col), dim=0) + x = self.x[node_id] + edge_attr = self.edge_attr[edge_id] + + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, + node_idx=node_id, edge_idx=edge_id) + + +# TODO: Refactor because composition >> inheritance + + +def _add_features_to_knn_index(knn_index: ApproxMIPSKNNIndex, emb: Tensor, + device: torch.device, batch_size: int = 2**20): + """Add new features to the existing KNN index in batches. + + Args: + knn_index (ApproxMIPSKNNIndex): Index to add features to. + emb (Tensor): Embeddings to add. + device (torch.device): Device to store in + batch_size (int, optional): Batch size to iterate by. + Defaults to 2**20, which equates to 4GB if working with + 1024 dim floats. + """ + for i in range(0, emb.size(0), batch_size): + if emb.size(0) - i >= batch_size: + emb_batch = emb[i:i + batch_size].to(device) + else: + emb_batch = emb[i:].to(device) + knn_index.add(emb_batch) + + +class ApproxKNNRAGFeatureStore(KNNRAGFeatureStore): + def __init__(self, enc_model: Type[Module], + model_kwargs: Optional[Dict[str, + Any]] = None, *args, **kwargs): + # TODO: Add kwargs for approx KNN to parameters here. + super().__init__(enc_model, model_kwargs, *args, **kwargs) + self.node_knn_index = None + self.edge_knn_index = None + + def _retrieve_seed_nodes_batch(self, query: Iterable[Any], + k_nodes: int) -> Iterator[InputNodes]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + enc_model = self.enc_model.to(self.device) + query_enc = enc_model.encode(query, + **self.model_kwargs).to(self.device) + del enc_model + gc.collect() + torch.cuda.empty_cache() + + if self.node_knn_index is None: + self.node_knn_index = ApproxMIPSKNNIndex(num_cells=100, + num_cells_to_visit=100, + bits_per_vector=4) + # Need to add in batches to avoid OOM + _add_features_to_knn_index(self.node_knn_index, self.x, + self.device) + + output = self.node_knn_index.search(query_enc, k=k_nodes) + yield from output.index + + def _retrieve_seed_edges_batch(self, query: Iterable[Any], + k_edges: int) -> Iterator[InputEdges]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + enc_model = self.enc_model.to(self.device) + query_enc = enc_model.encode(query, + **self.model_kwargs).to(self.device) + del enc_model + gc.collect() + torch.cuda.empty_cache() + + if self.edge_knn_index is None: + self.edge_knn_index = ApproxMIPSKNNIndex(num_cells=100, + num_cells_to_visit=100, + bits_per_vector=4) + # Need to add in batches to avoid OOM + _add_features_to_knn_index(self.edge_knn_index, self.edge_attr, + self.device) + + output = self.edge_knn_index.search(query_enc, k=k_edges) + yield from output.index + + +# TODO: These two classes should be refactored +class SentenceTransformerFeatureStore(KNNRAGFeatureStore): + def __init__(self, *args, **kwargs): + kwargs['model_name'] = kwargs.get( + 'model_name', 'sentence-transformers/all-roberta-large-v1') + super().__init__(SentenceTransformer, *args, **kwargs) + + +class SentenceTransformerApproxFeatureStore(ApproxKNNRAGFeatureStore): + def __init__(self, *args, **kwargs): + kwargs['model_name'] = kwargs.get( + 'model_name', 'sentence-transformers/all-roberta-large-v1') + super().__init__(SentenceTransformer, *args, **kwargs) diff --git a/examples/llm/g_retriever_utils/rag_generate.py b/examples/llm/g_retriever_utils/rag_generate.py new file mode 100644 index 000000000000..c6895b453b0c --- /dev/null +++ b/examples/llm/g_retriever_utils/rag_generate.py @@ -0,0 +1,137 @@ +# %% +import argparse +from itertools import chain +from typing import Tuple + +import pandas as pd +import torch +import tqdm +from rag_backend_utils import create_remote_backend_from_triplets +from rag_feature_store import SentenceTransformerFeatureStore +from rag_graph_store import NeighborSamplingRAGGraphStore + +from torch_geometric.data import Data +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.datasets.web_qsp_dataset import ( + preprocess_triplet, + retrieval_via_pcst, +) +from torch_geometric.loader import RAGQueryLoader +from torch_geometric.nn.nlp import SentenceTransformer + +# %% +parser = argparse.ArgumentParser(description="""Generate new WebQSP subgraphs +NOTE: Evaluating with smaller samples may result in poorer performance for the trained models compared to untrained models.""" + ) +# TODO: Add more arguments for configuring rag params +parser.add_argument("--use_pcst", action="store_true") +parser.add_argument("--num_samples", type=int, default=4700) +parser.add_argument("--out_file", default="subg_results.pt") +args = parser.parse_args() + +# %% +ds = WebQSPDataset("dataset", limit=args.num_samples, verbose=True, + force_reload=True) + +# %% +triplets = chain.from_iterable(d['graph'] for d in ds.raw_dataset) + +# %% +questions = ds.raw_dataset['question'] + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SentenceTransformer( + model_name='sentence-transformers/all-roberta-large-v1').to(device) + +# %% +fs, gs = create_remote_backend_from_triplets( + triplets=triplets, node_embedding_model=model, + node_method_to_call="encode", path="backend", + pre_transform=preprocess_triplet, node_method_kwargs={ + "batch_size": 256 + }, graph_db=NeighborSamplingRAGGraphStore, + feature_db=SentenceTransformerFeatureStore).load() + +# %% + + +def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3, + topk_e: int = 3, + cost_e: float = 0.5) -> Tuple[Data, str]: + q_emb = model.encode(query) + textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index() + out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, + textual_edges, topk, topk_e, cost_e) + out_graph["desc"] = desc + return out_graph + + +def apply_retrieval_with_text(graph: Data, query: str) -> Tuple[Data, str]: + textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index() + desc = ( + textual_nodes.to_csv(index=False) + "\n" + + textual_edges.to_csv(index=False, columns=["src", "edge_attr", "dst"])) + graph["desc"] = desc + return graph + + +transform = apply_retrieval_via_pcst \ + if args.use_pcst else apply_retrieval_with_text + +query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 5}, + seed_edges_kwargs={"k_edges": 5}, + sampler_kwargs={"num_neighbors": [50] * 2}, + local_filter=transform) + + +# %% +# Accuracy Metrics to be added to Profiler +def _eidx_helper(subg: Data, ground_truth: Data): + subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx + if isinstance(subg_eidx, torch.Tensor): + subg_eidx = subg_eidx.tolist() + if isinstance(gt_eidx, torch.Tensor): + gt_eidx = gt_eidx.tolist() + subg_e = set(subg_eidx) + gt_e = set(gt_eidx) + return subg_e, gt_e + + +def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + total_e = set(range(num_edges)) + tp = len(subg_e & gt_e) + tn = len(total_e - (subg_e | gt_e)) + return (tp + tn) / num_edges + + +def check_retrieval_precision(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(subg_e) + + +def check_retrieval_recall(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(gt_e) + + +# %% +retrieval_stats = {"precision": [], "recall": [], "accuracy": []} +subgs = [] +node_len = [] +edge_len = [] +for subg in tqdm.tqdm(query_loader.query(q) for q in questions): + subgs.append(subg) + node_len.append(subg['x'].shape[0]) + edge_len.append(subg['edge_attr'].shape[0]) + +for i, subg in enumerate(subgs): + subg['question'] = questions[i] + subg['label'] = ds[i]['label'] + +pd.DataFrame.from_dict(retrieval_stats).to_csv( + args.out_file.split('.')[0] + '_metadata.csv') +torch.save(subgs, args.out_file) diff --git a/examples/llm/g_retriever_utils/rag_graph_store.py b/examples/llm/g_retriever_utils/rag_graph_store.py new file mode 100644 index 000000000000..48473f287233 --- /dev/null +++ b/examples/llm/g_retriever_utils/rag_graph_store.py @@ -0,0 +1,107 @@ +from typing import Optional, Union + +import torch +from torch import Tensor + +from torch_geometric.data import FeatureStore +from torch_geometric.distributed import LocalGraphStore +from torch_geometric.sampler import ( + HeteroSamplerOutput, + NeighborSampler, + NodeSamplerInput, + SamplerOutput, +) +from torch_geometric.sampler.neighbor_sampler import NumNeighborsType +from torch_geometric.typing import EdgeTensorType, InputEdges, InputNodes + + +class NeighborSamplingRAGGraphStore(LocalGraphStore): + def __init__(self, feature_store: Optional[FeatureStore] = None, + num_neighbors: NumNeighborsType = [1], **kwargs): + self.feature_store = feature_store + self._num_neighbors = num_neighbors + self.sample_kwargs = kwargs + self._sampler_is_initialized = False + super().__init__() + + def _init_sampler(self): + if self.feature_store is None: + raise AttributeError("Feature store not registered yet.") + self.sampler = NeighborSampler(data=(self.feature_store, self), + num_neighbors=self._num_neighbors, + **self.sample_kwargs) + self._sampler_is_initialized = True + + def register_feature_store(self, feature_store: FeatureStore): + self.feature_store = feature_store + self._sampler_is_initialized = False + + def put_edge_id(self, edge_id: Tensor, *args, **kwargs) -> bool: + ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs) + self._sampler_is_initialized = False + return ret + + @property + def edge_index(self): + return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs) + + def put_edge_index(self, edge_index: EdgeTensorType, *args, + **kwargs) -> bool: + ret = super().put_edge_index(edge_index, *args, **kwargs) + # HACK + self.edge_idx_args = args + self.edge_idx_kwargs = kwargs + self._sampler_is_initialized = False + return ret + + @property + def num_neighbors(self): + return self._num_neighbors + + @num_neighbors.setter + def num_neighbors(self, num_neighbors: NumNeighborsType): + self._num_neighbors = num_neighbors + if hasattr(self, 'sampler'): + self.sampler.num_neighbors = num_neighbors + + def sample_subgraph( + self, seed_nodes: InputNodes, seed_edges: InputEdges, + num_neighbors: Optional[NumNeighborsType] = None + ) -> Union[SamplerOutput, HeteroSamplerOutput]: + """Sample the graph starting from the given nodes and edges using the + in-built NeighborSampler. + + Args: + seed_nodes (InputNodes): Seed nodes to start sampling from. + seed_edges (InputEdges): Seed edges to start sampling from. + num_neighbors (Optional[NumNeighborsType], optional): Parameters + to determine how many hops and number of neighbors per hop. + Defaults to None. + + Returns: + Union[SamplerOutput, HeteroSamplerOutput]: NeighborSamplerOutput + for the input. + """ + if not self._sampler_is_initialized: + self._init_sampler() + if num_neighbors is not None: + self.num_neighbors = num_neighbors + + # FIXME: Right now, only input nodes/edges as tensors are be supported + if not isinstance(seed_nodes, Tensor): + raise NotImplementedError + if not isinstance(seed_edges, Tensor): + raise NotImplementedError + device = seed_nodes.device + + # TODO: Call sample_from_edges for seed_edges + # Turning them into nodes for now. + seed_edges = self.edge_index.to(device).T[seed_edges.to( + device)].reshape(-1) + seed_nodes = torch.cat((seed_nodes, seed_edges), dim=0) + + seed_nodes = seed_nodes.unique().contiguous() + node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes) + out = self.sampler.sample_from_nodes(node_sample_input) + + return out diff --git a/examples/llm/multihop_rag/README.md b/examples/llm/multihop_rag/README.md new file mode 100644 index 000000000000..ff43b16a2c05 --- /dev/null +++ b/examples/llm/multihop_rag/README.md @@ -0,0 +1,9 @@ +# Examples for LLM and GNN co-training + +| Example | Description | +| -------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ | +| [`multihop_download.sh`](./multihop_download.sh) | Downloads all the components of the multihop dataset. | +| [`multihop_preprocess.py`](./multihop_preprocess.py) | Preprocesses the dataset to pair questions/answers with components in the knowledge graph. Contains documentation to describe the process. | +| [`rag_generate_multihop.py`](./rag_generate_multihop.py) | Utilizes the sample remote backend in [`g_retriever_utils`](../g_retriever_utils/) to generate subgraphs for the multihop dataset. | + +NOTE: Performance of GRetriever on this dataset has not been evaluated. diff --git a/examples/llm/multihop_rag/multihop_download.sh b/examples/llm/multihop_rag/multihop_download.sh new file mode 100644 index 000000000000..3c1970d39440 --- /dev/null +++ b/examples/llm/multihop_rag/multihop_download.sh @@ -0,0 +1,12 @@ +#!/bin/sh + +# Wikidata5m + +wget -O "wikidata5m_alias.tar.gz" "https://www.dropbox.com/s/lnbhc8yuhit4wm5/wikidata5m_alias.tar.gz" +tar -xvf "wikidata5m_alias.tar.gz" +wget -O "wikidata5m_all_triplet.txt.gz" "https://www.dropbox.com/s/563omb11cxaqr83/wikidata5m_all_triplet.txt.gz" +gzip -d "wikidata5m_all_triplet.txt.gz" -f + +# 2Multihopqa +wget -O "data_ids_april7.zip" "https://www.dropbox.com/s/ms2m13252h6xubs/data_ids_april7.zip" +unzip -o "data_ids_april7.zip" diff --git a/examples/llm/multihop_rag/multihop_preprocess.py b/examples/llm/multihop_rag/multihop_preprocess.py new file mode 100644 index 000000000000..46052bdf1b15 --- /dev/null +++ b/examples/llm/multihop_rag/multihop_preprocess.py @@ -0,0 +1,276 @@ +"""Example workflow for downloading and assembling a multihop QA dataset.""" + +import argparse +import json +from subprocess import call + +import pandas as pd +import torch +import tqdm + +from torch_geometric.data import LargeGraphIndexer + +# %% [markdown] +# # Encoding A Large Knowledge Graph Part 2 + +# %% [markdown] +# In this notebook, we will continue where we left off by building a new +# multi-hop QA dataset based on Wikidata. + +# %% [markdown] +# ## Example 2: Building a new Dataset from Questions and an already-existing +# Knowledge Graph + +# %% [markdown] +# ### Motivation + +# %% [markdown] +# One potential application of knowledge graph structural encodings is +# capturing the relationships between different entities that are multiple +# hops apart. This can be challenging for an LLM to recognize from prepended +# graph information. Here's a motivating example (credit to @Rishi Puri): + +# %% [markdown] +# In this example, the question can only be answered by reasoning about the +# relationships between the entities in the knowledge graph. + +# %% [markdown] +# ### Building a Multi-Hop QA Dataset + +# %% [markdown] +# To start, we need to download the raw data of a knowledge graph. +# In this case, we use WikiData5M +# ([Wang et al] +# (https://paperswithcode.com/paper/kepler-a-unified-model-for-knowledge)). +# Here we download the raw triplets and their entity codes. Information about +# this dataset can be found +# [here](https://deepgraphlearning.github.io/project/wikidata5m). + +# %% [markdown] +# The following download contains the ID to plaintext mapping for all the +# entities and relations in the knowledge graph: + +rv = call("./multihop_download.sh") + +# %% [markdown] +# To start, we are going to preprocess the knowledge graph to substitute each +# of the entity/relation codes with their plaintext aliases. This makes it +# easier to use a pre-trained textual encoding model to create triplet +# embeddings, as such a model likely won't understand how to properly embed +# the entity codes. + +# %% + +# %% +parser = argparse.ArgumentParser(description="Preprocess wikidata5m") +parser.add_argument("--n_triplets", type=int, default=-1) +args = parser.parse_args() + +# %% +# Substitute entity codes with their aliases +# Picking the first alias for each entity (rather arbitrarily) +alias_map = {} +rel_alias_map = {} +for line in open('wikidata5m_entity.txt'): + parts = line.strip().split('\t') + entity_id = parts[0] + aliases = parts[1:] + alias_map[entity_id] = aliases[0] +for line in open('wikidata5m_relation.txt'): + parts = line.strip().split('\t') + relation_id = parts[0] + relation_name = parts[1] + rel_alias_map[relation_id] = relation_name + +# %% +full_graph = [] +missing_total = 0 +total = 0 +limit = None if args.n_triplets == -1 else args.n_triplets +i = 0 + +for line in tqdm.tqdm(open('wikidata5m_all_triplet.txt')): + if limit is not None and i >= limit: + break + src, rel, dst = line.strip().split('\t') + if src not in alias_map: + missing_total += 1 + if dst not in alias_map: + missing_total += 1 + if rel not in rel_alias_map: + missing_total += 1 + total += 3 + full_graph.append([ + alias_map.get(src, src), + rel_alias_map.get(rel, rel), + alias_map.get(dst, dst) + ]) + i += 1 +print(f"Missing aliases: {missing_total}/{total}") + +# %% [markdown] +# Now `full_graph` represents the knowledge graph triplets in +# understandable plaintext. + +# %% [markdown] +# Next, we need a set of multi-hop questions that the Knowledge Graph will +# provide us with context for. We utilize a subset of +# [HotPotQA](https://hotpotqa.github.io/) +# ([Yang et. al.](https://arxiv.org/pdf/1809.09600)) called +# [2WikiMultiHopQA](https://github.com/Alab-NII/2wikimultihop) +# ([Ho et. al.](https://aclanthology.org/2020.coling-main.580.pdf)), +# which includes a subgraph of entities that serve as the ground truth +# justification for answering each multi-hop question: + +# %% +with open('train.json') as f: + train_data = json.load(f) +train_df = pd.DataFrame(train_data) +train_df['split_type'] = 'train' + +with open('dev.json') as f: + dev_data = json.load(f) +dev_df = pd.DataFrame(dev_data) +dev_df['split_type'] = 'dev' + +with open('test.json') as f: + test_data = json.load(f) +test_df = pd.DataFrame(test_data) +test_df['split_type'] = 'test' + +df = pd.concat([train_df, dev_df, test_df]) + +# %% [markdown] +# Now we need to extract the subgraphs + +# %% +df['graph_size'] = df['evidences_id'].apply(lambda row: len(row)) + +# %% [markdown] +# (Optional) We take only questions where the evidence graph is greater than +# 0. (Note: this gets rid of the test set): + +# %% +# df = df[df['graph_size'] > 0] + +# %% +refined_df = df[[ + '_id', 'question', 'answer', 'split_type', 'evidences_id', 'type', + 'graph_size' +]] + +# %% [markdown] +# Checkpoint: + +# %% +refined_df.to_csv('wikimultihopqa_refined.csv', index=False) + +# %% [markdown] +# Now we need to check that all the entities mentioned in the question/answer +# set are also present in the Wikidata graph: + +# %% +relation_map = {} +with open('wikidata5m_relation.txt') as f: + for line in tqdm.tqdm(f): + parts = line.strip().split('\t') + for i in range(1, len(parts)): + if parts[i] not in relation_map: + relation_map[parts[i]] = [] + relation_map[parts[i]].append(parts[0]) + +# %% +entity_set = set() +with open('wikidata5m_entity.txt') as f: + for line in tqdm.tqdm(f): + entity_set.add(line.strip().split('\t')[0]) + +# %% +missing_entities = set() +missing_entity_idx = set() +for i, row in enumerate(refined_df.itertuples()): + for trip in row.evidences_id: + entities = trip[0], trip[2] + for entity in entities: + if entity not in entity_set: + # print( + # f'The following entity was not found in the KG: {entity}' + # ) + missing_entities.add(entity) + missing_entity_idx.add(i) + +# %% [markdown] +# Right now, we drop the missing entity entries. Additional preprocessing can +# be done here to resolve the entity/relation collisions, but that is out of +# the scope for this notebook. + +# %% +# missing relations are ok, but missing entities cannot be mapped to +# plaintext, so they should be dropped. +refined_df.reset_index(inplace=True, drop=True) + +# %% +cleaned_df = refined_df.drop(missing_entity_idx) + +# %% [markdown] +# Now we save the resulting graph and questions/answers dataset: + +# %% +cleaned_df.to_csv('wikimultihopqa_cleaned.csv', index=False) + +# %% + +# %% +torch.save(full_graph, 'wikimultihopqa_full_graph.pt') + +# %% [markdown] +# ### Question: How do we extract a contextual subgraph for a given query? + +# %% [markdown] +# The chosen retrieval algorithm is a critical component in the pipeline for +# affecting RAG performance. In the next section (1), we will demonstrate a +# naive method of retrieval for a large knowledge graph, and how to apply it +# to this dataset along with WebQSP. + +# %% [markdown] +# ### Preparing a Textualized Graph for LLM + +# %% [markdown] +# For now however, we need to prepare the graph data to be used as a plaintext +# prefix to the LLM. In order to do this, we want to prompt the LLM to use the +# unique nodes, and unique edge triplets of a given subgraph. In order to do +# this, we prepare a unique indexed node df and edge df for the knowledge +# graph now. This process occurs trivially with the LargeGraphIndexer: + +# %% + +# %% +indexer = LargeGraphIndexer.from_triplets(full_graph) + +# %% +# Node DF +textual_nodes = pd.DataFrame.from_dict( + {"node_attr": indexer.get_node_features()}) +textual_nodes["node_id"] = textual_nodes.index +textual_nodes = textual_nodes[["node_id", "node_attr"]] + +# %% [markdown] +# Notice how LargeGraphIndexer ensures that there are no duplicate indices: + +# %% +# Edge DF +textual_edges = pd.DataFrame(indexer.get_edge_features(), + columns=["src", "edge_attr", "dst"]) +textual_edges["src"] = [indexer._nodes[h] for h in textual_edges["src"]] +textual_edges["dst"] = [indexer._nodes[h] for h in textual_edges["dst"]] + +# %% [markdown] +# Note: The edge table refers to each node by its index in the node table. +# We will see how this gets utilized later when indexing a subgraph. + +# %% [markdown] +# Now we can save the result + +# %% +textual_nodes.to_csv('wikimultihopqa_textual_nodes.csv', index=False) +textual_edges.to_csv('wikimultihopqa_textual_edges.csv', index=False) diff --git a/examples/llm/multihop_rag/rag_generate_multihop.py b/examples/llm/multihop_rag/rag_generate_multihop.py new file mode 100644 index 000000000000..de93a9e75dd1 --- /dev/null +++ b/examples/llm/multihop_rag/rag_generate_multihop.py @@ -0,0 +1,88 @@ +# %% +import argparse +import sys +from typing import Tuple + +import pandas as pd +import torch +import tqdm + +from torch_geometric.data import Data +from torch_geometric.datasets.web_qsp_dataset import ( + preprocess_triplet, + retrieval_via_pcst, +) +from torch_geometric.loader import RAGQueryLoader +from torch_geometric.nn.nlp import SentenceTransformer + +sys.path.append('..') + +from g_retriever_utils.rag_backend_utils import \ + create_remote_backend_from_triplets # noqa: E402 +from g_retriever_utils.rag_feature_store import \ + SentenceTransformerApproxFeatureStore # noqa: E402 +from g_retriever_utils.rag_graph_store import \ + NeighborSamplingRAGGraphStore # noqa: E402 + +# %% +parser = argparse.ArgumentParser( + description="Generate new multihop dataset for rag") +# TODO: Add more arguments for configuring rag params +parser.add_argument("--num_samples", type=int) +args = parser.parse_args() + +# %% +triplets = torch.load('wikimultihopqa_full_graph.pt') + +# %% +df = pd.read_csv('wikimultihopqa_cleaned.csv') +questions = df['question'][:args.num_samples] +labels = df['answer'][:args.num_samples] + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SentenceTransformer( + model_name='sentence-transformers/all-roberta-large-v1').to(device) + +# %% +fs, gs = create_remote_backend_from_triplets( + triplets=triplets, node_embedding_model=model, + node_method_to_call="encode", path="backend", + pre_transform=preprocess_triplet, node_method_kwargs={ + "batch_size": 256 + }, graph_db=NeighborSamplingRAGGraphStore, + feature_db=SentenceTransformerApproxFeatureStore).load() + +# %% + +all_textual_nodes = pd.read_csv('wikimultihopqa_textual_nodes.csv') +all_textual_edges = pd.read_csv('wikimultihopqa_textual_edges.csv') + + +def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3, + topk_e: int = 3, + cost_e: float = 0.5) -> Tuple[Data, str]: + q_emb = model.encode(query) + textual_nodes = all_textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = all_textual_edges.iloc[graph["edge_idx"]].reset_index() + out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, + textual_edges, topk, topk_e, cost_e) + out_graph["desc"] = desc + return out_graph + + +# %% +query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 10}, + seed_edges_kwargs={"k_edges": 10}, + sampler_kwargs={"num_neighbors": [40] * 3}, + local_filter=apply_retrieval_via_pcst) + +# %% +subgs = [] +for q, l in tqdm.tqdm(zip(questions, labels)): + subg = query_loader.query(q) + subg['question'] = q + subg['label'] = l + subgs.append(subg) + +torch.save(subgs, 'subg_results.pt') diff --git a/torch_geometric/loader/__init__.py b/torch_geometric/loader/__init__.py index 266f498a113b..7e83c35befb6 100644 --- a/torch_geometric/loader/__init__.py +++ b/torch_geometric/loader/__init__.py @@ -22,6 +22,7 @@ from .prefetch import PrefetchLoader from .cache import CachedLoader from .mixin import AffinityMixin +from .rag_loader import RAGQueryLoader __all__ = classes = [ 'DataLoader', @@ -50,6 +51,7 @@ 'PrefetchLoader', 'CachedLoader', 'AffinityMixin', + 'RAGQueryLoader', ] RandomNodeSampler = deprecated( diff --git a/torch_geometric/loader/rag_loader.py b/torch_geometric/loader/rag_loader.py new file mode 100644 index 000000000000..33d6cf0e868e --- /dev/null +++ b/torch_geometric/loader/rag_loader.py @@ -0,0 +1,106 @@ +from abc import abstractmethod +from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union + +from torch_geometric.data import Data, FeatureStore, HeteroData +from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput +from torch_geometric.typing import InputEdges, InputNodes + + +class RAGFeatureStore(Protocol): + """Feature store for remote GNN RAG backend.""" + @abstractmethod + def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes: + """Makes a comparison between the query and all the nodes to get all + the closest nodes. Return the indices of the nodes that are to be seeds + for the RAG Sampler. + """ + ... + + @abstractmethod + def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges: + """Makes a comparison between the query and all the edges to get all + the closest nodes. Returns the edge indices that are to be the seeds + for the RAG Sampler. + """ + ... + + @abstractmethod + def load_subgraph( + self, sample: Union[SamplerOutput, HeteroSamplerOutput] + ) -> Union[Data, HeteroData]: + """Combines sampled subgraph output with features in a Data object.""" + ... + + +class RAGGraphStore(Protocol): + """Graph store for remote GNN RAG backend.""" + @abstractmethod + def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges, + **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]: + """Sample a subgraph using the seeded nodes and edges.""" + ... + + @abstractmethod + def register_feature_store(self, feature_store: FeatureStore): + """Register a feature store to be used with the sampler. Samplers need + info from the feature store in order to work properly on HeteroGraphs. + """ + ... + + +# TODO: Make compatible with Heterographs + + +class RAGQueryLoader: + def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore], + local_filter: Optional[Callable[[Data, Any], Data]] = None, + seed_nodes_kwargs: Optional[Dict[str, Any]] = None, + seed_edges_kwargs: Optional[Dict[str, Any]] = None, + sampler_kwargs: Optional[Dict[str, Any]] = None, + loader_kwargs: Optional[Dict[str, Any]] = None): + """Loader meant for making queries from a remote backend. + + Args: + data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore + and GraphStore to load from. Assumed to conform to the + protocols listed above. + local_filter (Optional[Callable[[Data, Any], Data]], optional): + Optional local transform to apply to data after retrieval. + Defaults to None. + seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Paramaters + to pass into process for fetching seed nodes. Defaults to None. + seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters + to pass into process for fetching seed edges. Defaults to None. + sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to + pass into process for sampling graph. Defaults to None. + loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to + pass into process for loading graph features. Defaults to None. + """ + fstore, gstore = data + self.feature_store = fstore + self.graph_store = gstore + self.graph_store.register_feature_store(self.feature_store) + self.local_filter = local_filter + self.seed_nodes_kwargs = seed_nodes_kwargs or {} + self.seed_edges_kwargs = seed_edges_kwargs or {} + self.sampler_kwargs = sampler_kwargs or {} + self.loader_kwargs = loader_kwargs or {} + + def query(self, query: Any) -> Data: + """Retrieve a subgraph associated with the query with all its feature + attributes. + """ + seed_nodes = self.feature_store.retrieve_seed_nodes( + query, **self.seed_nodes_kwargs) + seed_edges = self.feature_store.retrieve_seed_edges( + query, **self.seed_edges_kwargs) + + subgraph_sample = self.graph_store.sample_subgraph( + seed_nodes, seed_edges, **self.sampler_kwargs) + + data = self.feature_store.load_subgraph(sample=subgraph_sample, + **self.loader_kwargs) + + if self.local_filter: + data = self.local_filter(data, query) + return data