diff --git a/CHANGELOG.md b/CHANGELOG.md
index 091b147a8de5..b253c0d06aff 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
+- Added `loader.RAGQueryLoader` with Remote Backend Example ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597))
- Added `data.LargeGraphIndexer` ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597))
- Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))
diff --git a/docs/source/_figures/flowchart.svg b/docs/source/_figures/flowchart.svg
new file mode 100644
index 000000000000..188d37b14f41
--- /dev/null
+++ b/docs/source/_figures/flowchart.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/_figures/multihop_example.svg b/docs/source/_figures/multihop_example.svg
new file mode 100644
index 000000000000..4925dcb9713d
--- /dev/null
+++ b/docs/source/_figures/multihop_example.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/_figures/remote_backend.svg b/docs/source/_figures/remote_backend.svg
new file mode 100644
index 000000000000..c5791f0a95de
--- /dev/null
+++ b/docs/source/_figures/remote_backend.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/advanced/rag.rst b/docs/source/advanced/rag.rst
new file mode 100644
index 000000000000..5473a44e5af3
--- /dev/null
+++ b/docs/source/advanced/rag.rst
@@ -0,0 +1,566 @@
+Working with LLM RAG in PyTorch Geometric
+=========================================
+
+This series aims to provide a starting point for
+multi-step LLM Retrieval Augmented Generation
+(RAG) using Graph Neural Networks.
+
+Motivation
+----------
+
+As Large Language Models (LLMs) quickly grow to dominate industry, they
+are increasingly being deployed at scale in use cases that require very
+specific contextual expertise. LLMs often struggle with these cases out
+of the box, as they tend to hallucinate answers when the relevant facts
+are not in their training data. At the same time, many businesses already
+have large graph databases full of domain-specific context that can
+reduce hallucination and improve answer fidelity for LLMs. Graph Neural
+Networks (GNNs) provide a means for efficiently encoding this contextual
+information into the model, which can help LLMs to better understand and
+generate answers. Hence, there is an open research question as to how to
+use GNN encodings effectively for this purpose, which the tooling
+provided here can help investigate.
+
+Architecture
+------------
+
+To model the use-case of RAG from a large knowledge graph of millions of
+nodes, we present the following architecture:
+
+
+
+
+
+.. figure:: ../_figures/flowchart.svg
+ :align: center
+ :width: 100%
+
+
+
+Graph RAG, as shown in the diagram above, proceeds in the following
+order:
+
+0. To start (not pictured here), a large knowledge graph must exist to
+   serve as a source of truth. The nodes and edges of this
+   knowledge graph
+
+During inference time, RAG implementations that follow this architecture
+are composed of the following steps:
+
+1. Tokenize and encode the query using the LLM Encoder
+2. Retrieve a subgraph of the larger knowledge graph (KG) relevant to
+ the query and encode it using a GNN
+3. Jointly embed the GNN embedding with the LLM embedding
+4. Use the LLM Decoder to decode the joint embedding and generate a response
+
+
+
+
+Encoding a Large Knowledge Graph
+================================
+
+To start, a Large Knowledge Graph needs to be created from triplets or
+multiple subgraphs in a dataset.
+
+Example 1: Building from Already Existing Datasets
+--------------------------------------------------
+
+In most RAG scenarios, the subset of the information corpus that gets
+retrieved is crucial to whether the LLM produces an appropriate response.
+The same is true for GNN-based RAG. For example, consider the
+WebQSPDataset.
+
+.. code:: python
+
+ from torch_geometric.datasets import WebQSPDataset
+
+ num_questions = 100
+ ds = WebQSPDataset('small_sample', limit=num_questions)
+
+
+WebQSP is a dataset based on a subset of the Freebase
+Knowledge Graph, an open-source knowledge graph formerly
+maintained by Google. For each question-answer pair in the dataset, a
+subgraph was chosen based on a Semantic SPARQL search on the larger
+knowledge graph, to provide relevant context for finding the answer. So
+each entry in the dataset consists of:
+
+- A question to be answered
+- The answer
+- A knowledge graph subgraph of Freebase that has the context
+  needed to answer the question.
+
+.. code:: python
+
+ ds.raw_dataset
+
+ >>> Dataset({
+ features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
+ num_rows: 100
+ })
+
+
+
+.. code:: python
+
+ ds.raw_dataset[0]
+
+
+ >>> {'id': 'WebQTrn-0',
+ 'question': 'what is the name of justin bieber brother',
+ 'answer': ['Jaxon Bieber'],
+ 'q_entity': ['Justin Bieber'],
+ 'a_entity': ['Jaxon Bieber'],
+ 'graph': [['P!nk', 'freebase.valuenotation.is_reviewed', 'Gender'],
+ ['1Club.FM: Power', 'broadcast.content.artist', 'P!nk'],
+ ...],
+ 'choices': []}
+
+
+
+Although this dataset can be trained on as-is, a couple of problems emerge
+from doing so:
+
+1. A retrieval algorithm needs to be implemented and executed during
+   inference time, and it might not correspond to the algorithm that was
+   used to generate the dataset subgraphs.
+2. The dataset as-is is not stored efficiently, as many duplicate nodes
+   and edges are shared between the questions.
+
+As a result, it makes sense in this scenario to be able to encode all
+the entries into a large knowledge graph, so that duplicate nodes and
+edges can be avoided, and so that alternative retrieval algorithms can
+be tried. We can do this with the LargeGraphIndexer class:
+
+.. code:: python
+
+ from torch_geometric.data import LargeGraphIndexer, Data, get_features_for_triplets_groups
+ from torch_geometric.nn.nlp import SentenceTransformer
+ import time
+ import torch
+ import tqdm
+ from itertools import chain
+ import networkx as nx
+
+.. code:: python
+
+ raw_dataset_graphs = [[tuple(trip) for trip in graph] for graph in ds.raw_dataset['graph']]
+ print(raw_dataset_graphs[0][:10])
+
+ >>> [('P!nk', 'freebase.valuenotation.is_reviewed', 'Gender'), ('1Club.FM: Power', 'broadcast.content.artist', 'P!nk'), ...]
+
+
+To show the benefits of this indexer in action, we will use the
+following model to encode this sample of graphs in two ways: with
+LargeGraphIndexer, and naively, one question at a time.
+
+.. code:: python
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = SentenceTransformer(model_name='sentence-transformers/all-roberta-large-v1').to(device)
+
+
+First, we compare the wall-clock times of encoding using both methods.
+
+.. code:: python
+
+    # Indexing question-by-question
+    dataset_graphs_embedded = []
+    start = time.time()
+    for graph in tqdm.tqdm(raw_dataset_graphs):
+        nodes_map = dict()
+        edges_map = dict()
+        edge_idx_base = []
+
+        for src, edge, dst in graph:
+            # Collect nodes
+            if src not in nodes_map:
+                nodes_map[src] = len(nodes_map)
+            if dst not in nodes_map:
+                nodes_map[dst] = len(nodes_map)
+
+            # Collect edge types
+            if edge not in edges_map:
+                edges_map[edge] = len(edges_map)
+
+            # Record edge
+            edge_idx_base.append((nodes_map[src], edges_map[edge], nodes_map[dst]))
+
+        # Encode nodes and edges
+        sorted_nodes = list(sorted(nodes_map.keys(), key=lambda x: nodes_map[x]))
+        sorted_edges = list(sorted(edges_map.keys(), key=lambda x: edges_map[x]))
+
+        x = model.encode(sorted_nodes, batch_size=256)
+        edge_attrs_map = model.encode(sorted_edges, batch_size=256)
+
+        edge_attrs = []
+        edge_idx = []
+        for trip in edge_idx_base:
+            edge_attrs.append(edge_attrs_map[trip[1]])
+            edge_idx.append([trip[0], trip[2]])
+
+        dataset_graphs_embedded.append(Data(x=x, edge_index=torch.tensor(edge_idx).T, edge_attr=torch.stack(edge_attrs, dim=0)))
+
+
+    print(time.time()-start)
+
+    >>> 121.68579435348511
+
+
+
+.. code:: python
+
+ # Using LargeGraphIndexer to make one large knowledge graph
+ from torch_geometric.data.large_graph_indexer import EDGE_RELATION
+
+ start = time.time()
+ all_triplets_together = chain.from_iterable(raw_dataset_graphs)
+ # Index as one large graph
+ print('Indexing...')
+ indexer = LargeGraphIndexer.from_triplets(all_triplets_together)
+
+ # first the nodes
+ unique_nodes = indexer.get_unique_node_features()
+ node_encs = model.encode(unique_nodes, batch_size=256)
+ indexer.add_node_feature(new_feature_name='x', new_feature_vals=node_encs)
+
+ # then the edges
+ unique_edges = indexer.get_unique_edge_features(feature_name=EDGE_RELATION)
+ edge_attr = model.encode(unique_edges, batch_size=256)
+ indexer.add_edge_feature(new_feature_name="edge_attr", new_feature_vals=edge_attr, map_from_feature=EDGE_RELATION)
+
+ ckpt_time = time.time()
+ whole_knowledge_graph = indexer.to_data(node_feature_name='x', edge_feature_name='edge_attr')
+ whole_graph_done = time.time()
+ print(f"Time to create whole knowledge_graph: {whole_graph_done-start}")
+
+ # Compute this to make sure we're comparing like to like on final time printout
+ whole_graph_diff = whole_graph_done-ckpt_time
+
+ # retrieve subgraphs
+ print('Retrieving Subgraphs...')
+ dataset_graphs_embedded_largegraphindexer = [graph for graph in tqdm.tqdm(get_features_for_triplets_groups(indexer=indexer, triplet_groups=raw_dataset_graphs), total=num_questions)]
+ print(time.time()-start-whole_graph_diff)
+
+ >>> Indexing...
+ >>> Time to create whole knowledge_graph: 114.01080107688904
+ >>> Retrieving Subgraphs...
+ >>> 114.66037964820862
+
+
+The large graph indexer allows us to compute the entire knowledge graph
+from a series of samples, so that new retrieval methods can also be
+tested on the entire graph. We will see this attempted in practice later
+on.
+
+It’s worth noting that, although the times are relatively similar right
+now, the speedup with ``LargeGraphIndexer`` will be much higher as the
+size of the knowledge graph grows. This is because the indexer encodes
+each unique node and edge only once, so the savings grow with the
+number of nodes and edges that are shared between samples.
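As a rough illustration of why the gap widens, the sketch below counts how many strings an encoder would have to process under each approach. The triplets are made up for illustration and no real model is called; the sketch only assumes that encoding cost scales with the number of strings passed to the encoder.

```python
from itertools import chain

# Hypothetical per-question triplet lists that share many entities.
question_graphs = [
    [("Justin Bieber", "sibling", "Jaxon Bieber"),
     ("Justin Bieber", "profession", "Singer")],
    [("Justin Bieber", "profession", "Singer"),
     ("P!nk", "profession", "Singer")],
    [("P!nk", "profession", "Singer"),
     ("Justin Bieber", "sibling", "Jaxon Bieber")],
]


def strings_to_encode(triplets):
    """Count unique node names plus unique relation names in one group."""
    nodes = {t[0] for t in triplets} | {t[2] for t in triplets}
    relations = {t[1] for t in triplets}
    return len(nodes) + len(relations)


# Question-by-question: shared entities are re-encoded for every question.
per_question = sum(strings_to_encode(g) for g in question_graphs)

# One large graph: every unique node and relation is encoded exactly once.
indexed_once = strings_to_encode(list(chain.from_iterable(question_graphs)))

print(per_question, indexed_once)
```

The more the per-question subgraphs overlap, the larger the difference between the two counts becomes.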
+
+
+We expect the two results to be functionally identical, with any
+differences being due to floating-point jitter.
+
+.. code:: python
+
+    def results_are_close_enough(ground_truth: Data, new_method: Data, thresh=.8):
+        def _sorted_tensors_are_close(tensor1, tensor2):
+            return torch.all(torch.isclose(tensor1.sort(dim=0)[0], tensor2.sort(dim=0)[0]).float().mean(axis=1) > thresh)
+        def _graphs_are_same(tensor1, tensor2):
+            return nx.weisfeiler_lehman_graph_hash(nx.Graph(tensor1.T)) == nx.weisfeiler_lehman_graph_hash(nx.Graph(tensor2.T))
+        return _sorted_tensors_are_close(ground_truth.x, new_method.x) \
+            and _sorted_tensors_are_close(ground_truth.edge_attr, new_method.edge_attr) \
+            and _graphs_are_same(ground_truth.edge_index, new_method.edge_index)
+
+
+    all_results_match = True
+    for old_graph, new_graph in tqdm.tqdm(zip(dataset_graphs_embedded, dataset_graphs_embedded_largegraphindexer), total=num_questions):
+        all_results_match &= results_are_close_enough(old_graph, new_graph)
+    all_results_match
+
+    >>> True
+
+
+
+When scaled up to the entire dataset, indexing this way yields a 2x
+speedup on the WebQSP dataset.
+
+Example 2: Building a new Dataset from Questions and an already-existing Knowledge Graph
+----------------------------------------------------------------------------------------
+
+Motivation
+~~~~~~~~~~
+
+One potential application of knowledge graph structural encodings is
+capturing the relationships between different entities that are multiple
+hops apart. This can be challenging for an LLM to recognize from
+prepended graph information. Here’s a motivating example (credit to
+@Rishi Puri):
+
+
+.. figure:: ../_figures/multihop_example.svg
+ :align: center
+ :width: 100%
+
+
+
+In this example, the question can only be answered by reasoning about
+the relationships between the entities in the knowledge graph.
+
+Building a Multi-Hop QA Dataset
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To see an example of encoding a large knowledge graph starting from an
+existing set of triplets, check out the multi-hop example in
+`examples/llm/multihop_rag`.
+
+Question: How do we extract a contextual subgraph for a given query?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The choice of retrieval algorithm is a critical factor in the RAG
+performance of the pipeline. In the next section, we will
+demonstrate a naive method of retrieval for a large knowledge graph.
+
+
+Retrieval Algorithms and Scaling Retrieval
+==========================================
+
+Motivation
+----------
+
+When building a RAG Pipeline for inference, the retrieval component is
+important for the following reasons:
+
+1. A given algorithm for retrieving subgraph context can have a
+   marked effect on the hallucination rate of the model’s responses.
+2. A given retrieval algorithm needs to be able to scale to larger
+   graphs of millions of nodes and edges in order to be practical for
+   production.
+
+In this section, we will explore how to construct a RAG retrieval
+algorithm from a given subgraph, and conduct some experiments to
+evaluate its runtime performance.
+
+We want to do so in line with PyTorch Geometric’s in-house framework for
+remote backends:
+
+
+.. figure:: ../_figures/remote_2.png
+ :align: center
+ :width: 100%
+
+
+
+As seen here, the GraphStore is used to store the neighbor relations
+between the nodes of the graph, whereas the FeatureStore is used to
+store the node and edge features in the graph.
+
+Let’s start by loading in a knowledge graph dataset for the sake of our
+experiment:
+
+.. code:: python
+
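+    import torch
+    import tqdm
+    from itertools import chain
+
+    from torch_geometric.data import Data, LargeGraphIndexer, get_features_for_triplets_groups
+    from torch_geometric.datasets import WebQSPDataset
+    from torch_geometric.nn.nlp import SentenceTransformer
+
+    ds = WebQSPDataset(root='demo', limit=10)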
+
+Let’s set up our set of questions and graph triplets:
+
+.. code:: python
+
+ questions = ds.raw_dataset['question']
+ questions
+
+ >>> ['what is the name of justin bieber brother',
+ 'what character did natalie portman play in star wars',
+ 'what country is the grand bahama island in',
+ 'what kind of money to take to bahamas',
+ 'what character did john noble play in lord of the rings',
+ 'who does joakim noah play for',
+ 'where are the nfl redskins from',
+ 'where did saki live',
+ 'who did draco malloy end up marrying',
+ 'which countries border the us']
+
+
+ ds.raw_dataset[:10]['graph'][0][:10]
+
+
+ >>> [['P!nk', 'freebase.valuenotation.is_reviewed', 'Gender'],
+ ['1Club.FM: Power', 'broadcast.content.artist', 'P!nk'],
+ ['Somebody to Love', 'music.recording.contributions', 'm.0rqp4h0'],
+ ['Rudolph Valentino', 'freebase.valuenotation.is_reviewed', 'Place of birth'],
+ ['Ice Cube', 'broadcast.artist.content', '.977 The Hits Channel'],
+ ['Colbie Caillat', 'broadcast.artist.content', 'Hot Wired Radio'],
+ ['Stephen Melton', 'people.person.nationality', 'United States of America'],
+ ['Record producer',
+ 'music.performance_role.regular_performances',
+ 'm.012m1vf1'],
+ ['Justin Bieber', 'award.award_winner.awards_won', 'm.0yrkc0l'],
+ ['1.FM Top 40', 'broadcast.content.artist', 'Geri Halliwell']]
+
+
+ all_triplets = chain.from_iterable((row['graph'] for row in ds.raw_dataset))
+
+With these questions and triplets, we want to:
+
+1. Consolidate all the relations in these triplets into a Knowledge Graph
+2. Create a FeatureStore that encodes all the nodes and edges in the knowledge graph
+3. Create a GraphStore that encodes all the edge indices in the knowledge graph
+
+In order to create a remote backend, we need to define a FeatureStore
+and GraphStore locally, as well as a method for initializing their state
+from triplets. The code used in this tutorial can be found in
+`examples/llm/g_retriever_utils`.
+
+.. code:: python
+
+    from torch_geometric.datasets.web_qsp_dataset import preprocess_triplet
+    from rag_backend_utils import create_remote_backend_from_triplets, RemoteGraphBackendLoader
+
+    # We define this GraphStore to sample the neighbors of a node locally.
+    # Ideally for a real remote backend, this interface would be replaced with an API to a Graph DB, such as Neo4j.
+    from rag_graph_store import NeighborSamplingRAGGraphStore
+
+    # We define this FeatureStore to encode the nodes and edges locally, and perform approximate KNN when indexing.
+    # Ideally for a real remote backend, this interface would be replaced with an API to a vector DB, such as Pinecone.
+    from rag_feature_store import SentenceTransformerFeatureStore
+
+.. code:: python
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = SentenceTransformer(model_name="sentence-transformers/all-roberta-large-v1").to(device)
+
+ backend_loader: RemoteGraphBackendLoader = create_remote_backend_from_triplets(
+ triplets=all_triplets, # All the triplets to insert into the backend
+ node_embedding_model=model, # Embedding model to process triplets with
+ node_method_to_call="encode", # This method will encode the nodes/edges with 'model.encode' in this case.
+ path="backend", # Save path
+ pre_transform=preprocess_triplet, # Preprocessing function to apply to triplets before invoking embedding model.
+ node_method_kwargs={"batch_size": 256}, # Keyword arguments to pass to the node_method_to_call.
+ graph_db=NeighborSamplingRAGGraphStore, # Graph Store to use
+ feature_db=SentenceTransformerFeatureStore # Feature Store to use
+ )
+ # This loader saves a copy of the processed data locally to be transformed into a graphstore and featurestore when load() is called.
+ feature_store, graph_store = backend_loader.load()
+
+Now that we have initialized our remote backends, we can retrieve
+from them using a Loader to query the backends, as shown in this
+diagram:
+
+
+.. figure:: ../_figures/remote_3.png
+ :align: center
+ :width: 100%
+
+
+
+.. code:: python
+
+    from torch_geometric.loader import RAGQueryLoader
+
+    query_loader = RAGQueryLoader(
+        data=(feature_store, graph_store), # Remote RAG FeatureStore and GraphStore
+        # Arguments to pass into the seed node/edge retrieval methods for the FeatureStore.
+        # In this case, it's k for the KNN on the nodes and edges.
+        seed_nodes_kwargs={"k_nodes": 10}, seed_edges_kwargs={"k_edges": 10},
+        # Arguments to pass into the GraphStore's Neighbor sampling method.
+        # In this case, the GraphStore implements a NeighborLoader, so it takes the same arguments.
+        sampler_kwargs={"num_neighbors": [40]*3},
+        # Arguments to pass into the FeatureStore's feature loading method.
+        loader_kwargs={},
+        # An optional local transform that can be applied on the returned subgraph.
+        local_filter=None,
+    )
+
+To make better sense of this loader’s arguments, let’s take a closer
+look at the retrieval process for a remote backend:
+
+
+.. figure:: ../_figures/remote_backend.svg
+ :align: center
+ :width: 100%
+
+
+
+As we see here, there are three important steps to any remote backend
+procedure for graphs:
+
+1. Retrieve the seed nodes and edges to begin our retrieval process from.
+2. Traverse the graph neighborhood of the seed nodes/edges to gather local context.
+3. Fetch the features associated with the subgraphs obtained from the traversal.
+
+We can see that our Query Loader construction allows us to specify
+unique hyperparameters for each unique step in this retrieval.
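To make the three steps concrete, here is a minimal pure-Python sketch on a toy graph. It is not how ``RAGQueryLoader`` is implemented: the function names (``retrieve_seed_nodes``, ``sample_neighborhood``, ``load_features``), the embeddings, and the edges are all hypothetical, and edges are treated as undirected for simplicity.

```python
import math
from collections import defaultdict

# Toy knowledge graph: node id -> embedding, plus a list of edges.
# All values here are made up for illustration.
node_emb = {
    0: [1.0, 0.0],
    1: [0.9, 0.1],
    2: [0.0, 1.0],
    3: [0.1, 0.9],
    4: [0.7, 0.7],
}
edges = [(0, 1), (1, 4), (2, 3), (3, 4)]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))


def retrieve_seed_nodes(query_emb, k):
    """Step 1: pick the k nodes whose embeddings best match the query."""
    ranked = sorted(node_emb, key=lambda n: cosine(query_emb, node_emb[n]),
                    reverse=True)
    return ranked[:k]


def sample_neighborhood(seeds, num_hops):
    """Step 2: expand the seeds by following edges for num_hops hops."""
    adj = defaultdict(set)
    for src, dst in edges:
        adj[src].add(dst)
        adj[dst].add(src)
    visited, frontier = set(seeds), set(seeds)
    for _ in range(num_hops):
        frontier = {n for f in frontier for n in adj[f]} - visited
        visited |= frontier
    return visited


def load_features(nodes):
    """Step 3: fetch features only for the nodes in the subgraph."""
    return {n: node_emb[n] for n in sorted(nodes)}


seeds = retrieve_seed_nodes(query_emb=[1.0, 0.0], k=2)
subgraph_nodes = sample_neighborhood(seeds, num_hops=1)
features = load_features(subgraph_nodes)
print(seeds, sorted(subgraph_nodes))
```

In this sketch, ``k`` plays the role of the seed-retrieval arguments and ``num_hops`` the role of the sampler arguments, which is why each step can be tuned independently.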
+
+Now we can submit our queries to the remote backend to retrieve our
+subgraphs:
+
+.. code:: python
+
+    sub_graphs = []
+    for q in tqdm.tqdm(questions):
+        sub_graphs.append(query_loader.query(q))
+
+
+    sub_graphs[0]
+
+    >>> Data(x=[2251, 1024], edge_index=[2, 7806], edge_attr=[7806, 1024], node_idx=[2251], edge_idx=[7806])
+
+
+
+These subgraphs are now retrieved using a different retrieval method
+than the one used for the original WebQSP dataset. Can we compare the
+properties of this method to the original WebQSPDataset’s retrieval
+method? Let’s compare some basic properties of the subgraphs:
+
+.. code:: python
+
+    def _eidx_helper(subg: Data, ground_truth: Data):
+        subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx
+        if isinstance(subg_eidx, torch.Tensor):
+            subg_eidx = subg_eidx.tolist()
+        if isinstance(gt_eidx, torch.Tensor):
+            gt_eidx = gt_eidx.tolist()
+        subg_e = set(subg_eidx)
+        gt_e = set(gt_eidx)
+        return subg_e, gt_e
+    def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int):
+        subg_e, gt_e = _eidx_helper(subg, ground_truth)
+        total_e = set(range(num_edges))
+        tp = len(subg_e & gt_e)
+        tn = len(total_e-(subg_e | gt_e))
+        return (tp+tn)/num_edges
+    def check_retrieval_precision(subg: Data, ground_truth: Data):
+        subg_e, gt_e = _eidx_helper(subg, ground_truth)
+        return len(subg_e & gt_e) / len(subg_e)
+    def check_retrieval_recall(subg: Data, ground_truth: Data):
+        subg_e, gt_e = _eidx_helper(subg, ground_truth)
+        return len(subg_e & gt_e) / len(gt_e)
+
+
+    ground_truth_graphs = get_features_for_triplets_groups(ds.indexer, (d['graph'] for d in ds.raw_dataset), pre_transform=preprocess_triplet)
+    num_edges = len(ds.indexer._edges)
+
+
+    for subg, ground_truth in tqdm.tqdm(zip((query_loader.query(q) for q in questions), ground_truth_graphs)):
+        print(f"Size: {len(subg.x)}, Ground Truth Size: {len(ground_truth.x)}, Accuracy: {check_retrieval_accuracy(subg, ground_truth, num_edges)}, Precision: {check_retrieval_precision(subg, ground_truth)}, Recall: {check_retrieval_recall(subg, ground_truth)}")
+
+ >>> Size: 2193, Ground Truth Size: 1709, Accuracy: 0.6636780705203827, Precision: 0.22923807012918535, Recall: 0.1994037381034285
+ >>> Size: 2682, Ground Truth Size: 1251, Accuracy: 0.7158736400576746, Precision: 0.10843513670738801, Recall: 0.22692963233503774
+ >>> Size: 2087, Ground Truth Size: 1285, Accuracy: 0.7979813868134749, Precision: 0.0547879177377892, Recall: 0.15757855822550831
+ >>> Size: 2975, Ground Truth Size: 1988, Accuracy: 0.6956088609254162, Precision: 0.14820555621795636, Recall: 0.21768826619964973
+ >>> Size: 2594, Ground Truth Size: 633, Accuracy: 0.78849128326124, Precision: 0.04202616198163095, Recall: 0.2032301480484522
+ >>> Size: 2462, Ground Truth Size: 1044, Accuracy: 0.7703499803381832, Precision: 0.07646643109540636, Recall: 0.19551861221539574
+ >>> Size: 2011, Ground Truth Size: 1382, Accuracy: 0.7871804954777821, Precision: 0.10117783355860205, Recall: 0.13142713819914723
+ >>> Size: 2011, Ground Truth Size: 1052, Accuracy: 0.802831301612269, Precision: 0.06452691407556001, Recall: 0.16702726092600606
+ >>> Size: 2892, Ground Truth Size: 1012, Accuracy: 0.7276182985974571, Precision: 0.10108615156751419, Recall: 0.20860927152317882
+ >>> Size: 1817, Ground Truth Size: 1978, Accuracy: 0.7530475815965395, Precision: 0.1677807486631016, Recall: 0.11696178937558248
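When reading numbers like these, it helps to see how the three metrics behave on a small worked case. The sketch below uses made-up edge ids (not taken from WebQSP) and the same set-based definitions as above. Because most edges in the graph belong to neither subgraph, true negatives dominate and accuracy stays high even when precision and recall are low.

```python
# Hypothetical edge ids, chosen only to illustrate the metrics.
num_edges = 1000                      # edges in the whole knowledge graph
retrieved = set(range(0, 100))        # edge ids returned by the retriever
ground_truth = set(range(80, 130))    # edge ids in the reference subgraph

tp = len(retrieved & ground_truth)              # edges in both sets
tn = num_edges - len(retrieved | ground_truth)  # edges in neither set

accuracy = (tp + tn) / num_edges
precision = tp / len(retrieved)
recall = tp / len(ground_truth)
print(accuracy, precision, recall)
```

Here only 20 of the 100 retrieved edges are in the reference subgraph, yet 870 of the 1000 edges are correctly left out, so accuracy is far higher than precision or recall.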
+
+
+
+Note that, since we’re only comparing the results of 10 graphs here,
+this retrieval algorithm is not taking into account the full corpus of
+nodes in the dataset. If you want to see a full example, look at
+``rag_generate.py`` or ``rag_generate_multihop.py``. These examples
+generate datasets for the entirety of the WebQSP dataset, or the
+WikiData Multihop datasets that are discussed in Section 0.
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 1c67eeddeec6..02315461c6ec 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -44,6 +44,7 @@ In addition, it consists of easy-to-use mini-batch loaders for operating on many
advanced/remote
advanced/graphgym
advanced/cpu_affinity
+ advanced/rag
.. toctree::
:maxdepth: 1
diff --git a/examples/llm/README.md b/examples/llm/README.md
index f1f01428d991..51043ae84a5c 100644
--- a/examples/llm/README.md
+++ b/examples/llm/README.md
@@ -1,5 +1,7 @@
# Examples for Co-training LLMs and GNNs
-| Example | Description |
-| ------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
+| Example | Description |
+| -------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
+| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. |
+| [`multihop_rag/`](./multihop_rag/) | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA |
diff --git a/examples/llm/g_retriever_utils/README.md b/examples/llm/g_retriever_utils/README.md
new file mode 100644
index 000000000000..0500b4a7e5ce
--- /dev/null
+++ b/examples/llm/g_retriever_utils/README.md
@@ -0,0 +1,8 @@
+# Examples for LLM and GNN co-training
+
+| Example | Description |
+| ---------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [`rag_feature_store.py`](./rag_feature_store.py) | A Proof of Concept Implementation of a RAG enabled FeatureStore that can serve as a starting point for implementing a custom RAG Remote Backend |
+| [`rag_graph_store.py`](./rag_graph_store.py) | A Proof of Concept Implementation of a RAG enabled GraphStore that can serve as a starting point for implementing a custom RAG Remote Backend |
+| [`rag_backend_utils.py`](./rag_backend_utils.py) | Utility functions used for loading a series of Knowledge Graph Triplets into the Remote Backend defined by a FeatureStore and GraphStore |
+| [`rag_generate.py`](./rag_generate.py) | Script for generating a unique set of subgraphs from the WebQSP dataset using a custom defined retrieval algorithm (defaults to the FeatureStore and GraphStore provided) |
diff --git a/examples/llm/g_retriever_utils/rag_backend_utils.py b/examples/llm/g_retriever_utils/rag_backend_utils.py
new file mode 100644
index 000000000000..0f1c0e1b87ec
--- /dev/null
+++ b/examples/llm/g_retriever_utils/rag_backend_utils.py
@@ -0,0 +1,224 @@
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ Optional,
+ Protocol,
+ Tuple,
+ Type,
+ runtime_checkable,
+)
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+
+from torch_geometric.data import (
+ FeatureStore,
+ GraphStore,
+ LargeGraphIndexer,
+ TripletLike,
+)
+from torch_geometric.data.large_graph_indexer import EDGE_RELATION
+from torch_geometric.distributed import (
+ LocalFeatureStore,
+ LocalGraphStore,
+ Partitioner,
+)
+from torch_geometric.typing import EdgeType, NodeType
+
+RemoteGraphBackend = Tuple[FeatureStore, GraphStore]
+
+# TODO: Make everything compatible with Hetero graphs as well
+
+
+# Adapted from LocalGraphStore
+@runtime_checkable
+class ConvertableGraphStore(Protocol):
+    @classmethod
+    def from_data(
+        cls,
+        edge_id: Tensor,
+        edge_index: Tensor,
+        num_nodes: int,
+        is_sorted: bool = False,
+    ) -> GraphStore:
+        ...
+
+    @classmethod
+    def from_hetero_data(
+        cls,
+        edge_id_dict: Dict[EdgeType, Tensor],
+        edge_index_dict: Dict[EdgeType, Tensor],
+        num_nodes_dict: Dict[NodeType, int],
+        is_sorted: bool = False,
+    ) -> GraphStore:
+        ...
+
+    @classmethod
+    def from_partition(cls, root: str, pid: int) -> GraphStore:
+        ...
+
+
+# Adapted from LocalFeatureStore
+@runtime_checkable
+class ConvertableFeatureStore(Protocol):
+    @classmethod
+    def from_data(
+        cls,
+        node_id: Tensor,
+        x: Optional[Tensor] = None,
+        y: Optional[Tensor] = None,
+        edge_id: Optional[Tensor] = None,
+        edge_attr: Optional[Tensor] = None,
+    ) -> FeatureStore:
+        ...
+
+    @classmethod
+    def from_hetero_data(
+        cls,
+        node_id_dict: Dict[NodeType, Tensor],
+        x_dict: Optional[Dict[NodeType, Tensor]] = None,
+        y_dict: Optional[Dict[NodeType, Tensor]] = None,
+        edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None,
+        edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,
+    ) -> FeatureStore:
+        ...
+
+    @classmethod
+    def from_partition(cls, root: str, pid: int) -> FeatureStore:
+        ...
+
+
+class RemoteDataType(Enum):
+    DATA = auto()
+    PARTITION = auto()
+
+
+@dataclass
+class RemoteGraphBackendLoader:
+    """Utility class to load triplets into a RAG Backend."""
+    path: str
+    datatype: RemoteDataType
+    graph_store_type: Type[ConvertableGraphStore]
+    feature_store_type: Type[ConvertableFeatureStore]
+
+    def load(self, pid: Optional[int] = None) -> RemoteGraphBackend:
+        if self.datatype == RemoteDataType.DATA:
+            data_obj = torch.load(self.path)
+            graph_store = self.graph_store_type.from_data(
+                edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index,
+                num_nodes=data_obj.num_nodes)
+            feature_store = self.feature_store_type.from_data(
+                node_id=data_obj['node_id'], x=data_obj.x,
+                edge_id=data_obj['edge_id'], edge_attr=data_obj.edge_attr)
+        elif self.datatype == RemoteDataType.PARTITION:
+            # A partition ID is required to select which partition to load.
+            assert pid is not None, \
+                "Partition ID must be defined for loading from a " \
+                "partitioned store."
+            graph_store = self.graph_store_type.from_partition(self.path, pid)
+            feature_store = self.feature_store_type.from_partition(
+                self.path, pid)
+        else:
+            raise NotImplementedError
+        return (feature_store, graph_store)
+
+
+# TODO: make profilable
+def create_remote_backend_from_triplets(
+ triplets: Iterable[TripletLike], node_embedding_model: Module,
+ edge_embedding_model: Module | None = None,
+ graph_db: Type[ConvertableGraphStore] = LocalGraphStore,
+ feature_db: Type[ConvertableFeatureStore] = LocalFeatureStore,
+ node_method_to_call: str = "forward",
+ edge_method_to_call: str | None = None,
+ pre_transform: Callable[[TripletLike], TripletLike] | None = None,
+ path: str = '', n_parts: int = 1,
+ node_method_kwargs: Optional[Dict[str, Any]] = None,
+ edge_method_kwargs: Optional[Dict[str, Any]] = None
+) -> RemoteGraphBackendLoader:
+ """Utility function that can be used to create a RAG Backend from triplets.
+
+ Args:
+ triplets (Iterable[TripletLike]): Triplets to load into the RAG
+ Backend.
+ node_embedding_model (Module): Model to embed nodes into a feature
+ space.
+ edge_embedding_model (Module | None, optional): Model to embed edges
+ into a feature space. Defaults to the node model.
+ graph_db (Type[ConvertableGraphStore], optional): GraphStore class to
+ use. Defaults to LocalGraphStore.
+ feature_db (Type[ConvertableFeatureStore], optional): FeatureStore
+ class to use. Defaults to LocalFeatureStore.
+ node_method_to_call (str, optional): method to call for embeddings on
+ the node model. Defaults to "forward".
+ edge_method_to_call (str | None, optional): method to call for
+ embeddings on the edge model. Defaults to the node method.
+ pre_transform (Callable[[TripletLike], TripletLike] | None, optional):
+ optional preprocessing function for triplets. Defaults to None.
+ path (str, optional): path to save resulting stores. Defaults to ''.
+ n_parts (int, optional): Number of partitons to store in.
+ Defaults to 1.
+ node_method_kwargs (Optional[Dict[str, Any]], optional): args to pass
+ into node encoding method. Defaults to None.
+ edge_method_kwargs (Optional[Dict[str, Any]], optional): args to pass
+ into edge encoding method. Defaults to None.
+
+ Returns:
+ RemoteGraphBackendLoader: Loader to load RAG backend from disk or
+ memory.
+ """
+ # Will return attribute errors for missing attributes
+ if not issubclass(graph_db, ConvertableGraphStore):
+ getattr(graph_db, "from_data")
+ getattr(graph_db, "from_hetero_data")
+ getattr(graph_db, "from_partition")
+ if not issubclass(feature_db, ConvertableFeatureStore):
+ getattr(feature_db, "from_data")
+ getattr(feature_db, "from_hetero_data")
+ getattr(feature_db, "from_partition")
+
+ # Resolve callable methods
+ node_method_kwargs = node_method_kwargs \
+ if node_method_kwargs is not None else dict()
+
+ edge_embedding_model = edge_embedding_model \
+ if edge_embedding_model is not None else node_embedding_model
+ edge_method_to_call = edge_method_to_call \
+ if edge_method_to_call is not None else node_method_to_call
+ edge_method_kwargs = edge_method_kwargs \
+ if edge_method_kwargs is not None else node_method_kwargs
+
+ # These raise AttributeError if the methods don't exist
+ node_model = getattr(node_embedding_model, node_method_to_call)
+ edge_model = getattr(edge_embedding_model, edge_method_to_call)
+
+ indexer = LargeGraphIndexer.from_triplets(triplets,
+ pre_transform=pre_transform)
+
+ node_feats = node_model(indexer.get_node_features(), **node_method_kwargs)
+ indexer.add_node_feature('x', node_feats)
+
+ edge_feats = edge_model(
+ indexer.get_unique_edge_features(feature_name=EDGE_RELATION),
+ **edge_method_kwargs)
+ indexer.add_edge_feature(new_feature_name="edge_attr",
+ new_feature_vals=edge_feats,
+ map_from_feature=EDGE_RELATION)
+
+ data = indexer.to_data(node_feature_name='x',
+ edge_feature_name='edge_attr')
+
+ if n_parts == 1:
+ torch.save(data, path)
+ return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db,
+ feature_db)
+ else:
+ partitioner = Partitioner(data=data, num_parts=n_parts, root=path)
+ partitioner.generate_partition()
+ return RemoteGraphBackendLoader(path, RemoteDataType.PARTITION,
+ graph_db, feature_db)
diff --git a/examples/llm/g_retriever_utils/rag_feature_store.py b/examples/llm/g_retriever_utils/rag_feature_store.py
new file mode 100644
index 000000000000..e01e9e59bb88
--- /dev/null
+++ b/examples/llm/g_retriever_utils/rag_feature_store.py
@@ -0,0 +1,189 @@
+import gc
+from collections.abc import Iterable, Iterator
+from typing import Any, Dict, Optional, Type, Union
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+from torchmetrics.functional import pairwise_cosine_similarity
+
+from torch_geometric.data import Data, HeteroData
+from torch_geometric.distributed import LocalFeatureStore
+from torch_geometric.nn.nlp import SentenceTransformer
+from torch_geometric.nn.pool import ApproxMIPSKNNIndex
+from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
+from torch_geometric.typing import InputEdges, InputNodes
+
+
+# NOTE: Only compatible with Homogeneous graphs for now
+class KNNRAGFeatureStore(LocalFeatureStore):
+ def __init__(self, enc_model: Type[Module],
+ model_kwargs: Optional[Dict[str,
+ Any]] = None, *args, **kwargs):
+ self.device = torch.device(
+ "cuda" if torch.cuda.is_available() else "cpu")
+ self.enc_model = enc_model(*args, **kwargs).to(self.device)
+ self.enc_model.eval()
+ self.model_kwargs = \
+ model_kwargs if model_kwargs is not None else dict()
+ super().__init__()
+
+ @property
+ def x(self) -> Tensor:
+ return self.get_tensor(group_name=None, attr_name='x')
+
+ @property
+ def edge_attr(self) -> Tensor:
+ return self.get_tensor(group_name=(None, None), attr_name='edge_attr')
+
+ def retrieve_seed_nodes(self, query: Any, k_nodes: int = 5) -> InputNodes:
+ result = next(self._retrieve_seed_nodes_batch([query], k_nodes))
+ gc.collect()
+ torch.cuda.empty_cache()
+ return result
+
+ def _retrieve_seed_nodes_batch(self, query: Iterable[Any],
+ k_nodes: int) -> Iterator[InputNodes]:
+ if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
+ raise NotImplementedError
+
+ query_enc = self.enc_model.encode(query,
+ **self.model_kwargs).to(self.device)
+ prizes = pairwise_cosine_similarity(query_enc, self.x.to(self.device))
+ topk = min(k_nodes, len(self.x))
+ for q in prizes:
+ _, indices = torch.topk(q, topk, largest=True)
+ yield indices
+
+ def retrieve_seed_edges(self, query: Any, k_edges: int = 3) -> InputEdges:
+ result = next(self._retrieve_seed_edges_batch([query], k_edges))
+ gc.collect()
+ torch.cuda.empty_cache()
+ return result
+
+ def _retrieve_seed_edges_batch(self, query: Iterable[Any],
+ k_edges: int) -> Iterator[InputEdges]:
+ if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
+ raise NotImplementedError
+
+ query_enc = self.enc_model.encode(query,
+ **self.model_kwargs).to(self.device)
+
+ prizes = pairwise_cosine_similarity(query_enc,
+ self.edge_attr.to(self.device))
+ topk = min(k_edges, len(self.edge_attr))
+ for q in prizes:
+ _, indices = torch.topk(q, topk, largest=True)
+ yield indices
+
+ def load_subgraph(
+ self, sample: Union[SamplerOutput, HeteroSamplerOutput]
+ ) -> Union[Data, HeteroData]:
+
+ if isinstance(sample, HeteroSamplerOutput):
+ raise NotImplementedError
+
+ # NOTE: torch_geometric.loader.utils.filter_custom_store can be used
+ # here if it supported edge features
+ node_id = sample.node
+ edge_id = sample.edge
+ edge_index = torch.stack((sample.row, sample.col), dim=0)
+ x = self.x[node_id]
+ edge_attr = self.edge_attr[edge_id]
+
+ return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
+ node_idx=node_id, edge_idx=edge_id)
+
+
+# TODO: Refactor because composition >> inheritance
+
+
+def _add_features_to_knn_index(knn_index: ApproxMIPSKNNIndex, emb: Tensor,
+ device: torch.device, batch_size: int = 2**20):
+ """Add new features to the existing KNN index in batches.
+
+ Args:
+ knn_index (ApproxMIPSKNNIndex): Index to add features to.
+ emb (Tensor): Embeddings to add.
+ device (torch.device): Device to store the embeddings on.
+ batch_size (int, optional): Batch size to iterate by.
+ Defaults to 2**20, which equates to 4GB if working with
+ 1024 dim floats.
+ """
+ for i in range(0, emb.size(0), batch_size):
+ if emb.size(0) - i >= batch_size:
+ emb_batch = emb[i:i + batch_size].to(device)
+ else:
+ emb_batch = emb[i:].to(device)
+ knn_index.add(emb_batch)
+
+
+class ApproxKNNRAGFeatureStore(KNNRAGFeatureStore):
+ def __init__(self, enc_model: Type[Module],
+ model_kwargs: Optional[Dict[str,
+ Any]] = None, *args, **kwargs):
+ # TODO: Add kwargs for approx KNN to parameters here.
+ super().__init__(enc_model, model_kwargs, *args, **kwargs)
+ self.node_knn_index = None
+ self.edge_knn_index = None
+
+ def _retrieve_seed_nodes_batch(self, query: Iterable[Any],
+ k_nodes: int) -> Iterator[InputNodes]:
+ if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
+ raise NotImplementedError
+
+ enc_model = self.enc_model.to(self.device)
+ query_enc = enc_model.encode(query,
+ **self.model_kwargs).to(self.device)
+ del enc_model
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ if self.node_knn_index is None:
+ self.node_knn_index = ApproxMIPSKNNIndex(num_cells=100,
+ num_cells_to_visit=100,
+ bits_per_vector=4)
+ # Need to add in batches to avoid OOM
+ _add_features_to_knn_index(self.node_knn_index, self.x,
+ self.device)
+
+ output = self.node_knn_index.search(query_enc, k=k_nodes)
+ yield from output.index
+
+ def _retrieve_seed_edges_batch(self, query: Iterable[Any],
+ k_edges: int) -> Iterator[InputEdges]:
+ if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
+ raise NotImplementedError
+
+ enc_model = self.enc_model.to(self.device)
+ query_enc = enc_model.encode(query,
+ **self.model_kwargs).to(self.device)
+ del enc_model
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ if self.edge_knn_index is None:
+ self.edge_knn_index = ApproxMIPSKNNIndex(num_cells=100,
+ num_cells_to_visit=100,
+ bits_per_vector=4)
+ # Need to add in batches to avoid OOM
+ _add_features_to_knn_index(self.edge_knn_index, self.edge_attr,
+ self.device)
+
+ output = self.edge_knn_index.search(query_enc, k=k_edges)
+ yield from output.index
+
+
+# TODO: These two classes should be refactored
+class SentenceTransformerFeatureStore(KNNRAGFeatureStore):
+ def __init__(self, *args, **kwargs):
+ kwargs['model_name'] = kwargs.get(
+ 'model_name', 'sentence-transformers/all-roberta-large-v1')
+ super().__init__(SentenceTransformer, *args, **kwargs)
+
+
+class SentenceTransformerApproxFeatureStore(ApproxKNNRAGFeatureStore):
+ def __init__(self, *args, **kwargs):
+ kwargs['model_name'] = kwargs.get(
+ 'model_name', 'sentence-transformers/all-roberta-large-v1')
+ super().__init__(SentenceTransformer, *args, **kwargs)
diff --git a/examples/llm/g_retriever_utils/rag_generate.py b/examples/llm/g_retriever_utils/rag_generate.py
new file mode 100644
index 000000000000..c6895b453b0c
--- /dev/null
+++ b/examples/llm/g_retriever_utils/rag_generate.py
@@ -0,0 +1,137 @@
+# %%
+import argparse
+from itertools import chain
+from typing import Tuple
+
+import pandas as pd
+import torch
+import tqdm
+from rag_backend_utils import create_remote_backend_from_triplets
+from rag_feature_store import SentenceTransformerFeatureStore
+from rag_graph_store import NeighborSamplingRAGGraphStore
+
+from torch_geometric.data import Data
+from torch_geometric.datasets import WebQSPDataset
+from torch_geometric.datasets.web_qsp_dataset import (
+ preprocess_triplet,
+ retrieval_via_pcst,
+)
+from torch_geometric.loader import RAGQueryLoader
+from torch_geometric.nn.nlp import SentenceTransformer
+
+# %%
+parser = argparse.ArgumentParser(description="""Generate new WebQSP subgraphs
+NOTE: Evaluating with smaller samples may result in poorer performance for the trained models compared to untrained models."""
+ )
+# TODO: Add more arguments for configuring rag params
+parser.add_argument("--use_pcst", action="store_true")
+parser.add_argument("--num_samples", type=int, default=4700)
+parser.add_argument("--out_file", default="subg_results.pt")
+args = parser.parse_args()
+
+# %%
+ds = WebQSPDataset("dataset", limit=args.num_samples, verbose=True,
+ force_reload=True)
+
+# %%
+triplets = chain.from_iterable(d['graph'] for d in ds.raw_dataset)
+
+# %%
+questions = ds.raw_dataset['question']
+
+# %%
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = SentenceTransformer(
+ model_name='sentence-transformers/all-roberta-large-v1').to(device)
+
+# %%
+fs, gs = create_remote_backend_from_triplets(
+ triplets=triplets, node_embedding_model=model,
+ node_method_to_call="encode", path="backend",
+ pre_transform=preprocess_triplet, node_method_kwargs={
+ "batch_size": 256
+ }, graph_db=NeighborSamplingRAGGraphStore,
+ feature_db=SentenceTransformerFeatureStore).load()
+
+# %%
+
+
+def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3,
+ topk_e: int = 3,
+ cost_e: float = 0.5) -> Data:
+ q_emb = model.encode(query)
+ textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index()
+ textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index()
+ out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes,
+ textual_edges, topk, topk_e, cost_e)
+ out_graph["desc"] = desc
+ return out_graph
+
+
+def apply_retrieval_with_text(graph: Data, query: str) -> Data:
+ textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index()
+ textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index()
+ desc = (
+ textual_nodes.to_csv(index=False) + "\n" +
+ textual_edges.to_csv(index=False, columns=["src", "edge_attr", "dst"]))
+ graph["desc"] = desc
+ return graph
+
+
+transform = apply_retrieval_via_pcst \
+ if args.use_pcst else apply_retrieval_with_text
+
+query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 5},
+ seed_edges_kwargs={"k_edges": 5},
+ sampler_kwargs={"num_neighbors": [50] * 2},
+ local_filter=transform)
+
+
+# %%
+# Accuracy Metrics to be added to Profiler
+def _eidx_helper(subg: Data, ground_truth: Data):
+ subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx
+ if isinstance(subg_eidx, torch.Tensor):
+ subg_eidx = subg_eidx.tolist()
+ if isinstance(gt_eidx, torch.Tensor):
+ gt_eidx = gt_eidx.tolist()
+ subg_e = set(subg_eidx)
+ gt_e = set(gt_eidx)
+ return subg_e, gt_e
+
+
+def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int):
+ subg_e, gt_e = _eidx_helper(subg, ground_truth)
+ total_e = set(range(num_edges))
+ tp = len(subg_e & gt_e)
+ tn = len(total_e - (subg_e | gt_e))
+ return (tp + tn) / num_edges
+
+
+def check_retrieval_precision(subg: Data, ground_truth: Data):
+ subg_e, gt_e = _eidx_helper(subg, ground_truth)
+ return len(subg_e & gt_e) / len(subg_e)
+
+
+def check_retrieval_recall(subg: Data, ground_truth: Data):
+ subg_e, gt_e = _eidx_helper(subg, ground_truth)
+ return len(subg_e & gt_e) / len(gt_e)
+
+
+# %%
+retrieval_stats = {"precision": [], "recall": [], "accuracy": []}
+subgs = []
+node_len = []
+edge_len = []
+for subg in tqdm.tqdm(query_loader.query(q) for q in questions):
+ subgs.append(subg)
+ node_len.append(subg['x'].shape[0])
+ edge_len.append(subg['edge_attr'].shape[0])
+
+for i, subg in enumerate(subgs):
+ subg['question'] = questions[i]
+ subg['label'] = ds[i]['label']
+
+pd.DataFrame.from_dict(retrieval_stats).to_csv(
+ args.out_file.split('.')[0] + '_metadata.csv')
+torch.save(subgs, args.out_file)
diff --git a/examples/llm/g_retriever_utils/rag_graph_store.py b/examples/llm/g_retriever_utils/rag_graph_store.py
new file mode 100644
index 000000000000..48473f287233
--- /dev/null
+++ b/examples/llm/g_retriever_utils/rag_graph_store.py
@@ -0,0 +1,107 @@
+from typing import Optional, Union
+
+import torch
+from torch import Tensor
+
+from torch_geometric.data import FeatureStore
+from torch_geometric.distributed import LocalGraphStore
+from torch_geometric.sampler import (
+ HeteroSamplerOutput,
+ NeighborSampler,
+ NodeSamplerInput,
+ SamplerOutput,
+)
+from torch_geometric.sampler.neighbor_sampler import NumNeighborsType
+from torch_geometric.typing import EdgeTensorType, InputEdges, InputNodes
+
+
+class NeighborSamplingRAGGraphStore(LocalGraphStore):
+ def __init__(self, feature_store: Optional[FeatureStore] = None,
+ num_neighbors: NumNeighborsType = [1], **kwargs):
+ self.feature_store = feature_store
+ self._num_neighbors = num_neighbors
+ self.sample_kwargs = kwargs
+ self._sampler_is_initialized = False
+ super().__init__()
+
+ def _init_sampler(self):
+ if self.feature_store is None:
+ raise AttributeError("Feature store not registered yet.")
+ self.sampler = NeighborSampler(data=(self.feature_store, self),
+ num_neighbors=self._num_neighbors,
+ **self.sample_kwargs)
+ self._sampler_is_initialized = True
+
+ def register_feature_store(self, feature_store: FeatureStore):
+ self.feature_store = feature_store
+ self._sampler_is_initialized = False
+
+ def put_edge_id(self, edge_id: Tensor, *args, **kwargs) -> bool:
+ ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs)
+ self._sampler_is_initialized = False
+ return ret
+
+ @property
+ def edge_index(self):
+ return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs)
+
+ def put_edge_index(self, edge_index: EdgeTensorType, *args,
+ **kwargs) -> bool:
+ ret = super().put_edge_index(edge_index, *args, **kwargs)
+ # HACK
+ self.edge_idx_args = args
+ self.edge_idx_kwargs = kwargs
+ self._sampler_is_initialized = False
+ return ret
+
+ @property
+ def num_neighbors(self):
+ return self._num_neighbors
+
+ @num_neighbors.setter
+ def num_neighbors(self, num_neighbors: NumNeighborsType):
+ self._num_neighbors = num_neighbors
+ if hasattr(self, 'sampler'):
+ self.sampler.num_neighbors = num_neighbors
+
+ def sample_subgraph(
+ self, seed_nodes: InputNodes, seed_edges: InputEdges,
+ num_neighbors: Optional[NumNeighborsType] = None
+ ) -> Union[SamplerOutput, HeteroSamplerOutput]:
+ """Sample the graph starting from the given nodes and edges using the
+ in-built NeighborSampler.
+
+ Args:
+ seed_nodes (InputNodes): Seed nodes to start sampling from.
+ seed_edges (InputEdges): Seed edges to start sampling from.
+ num_neighbors (Optional[NumNeighborsType], optional): Parameters
+ to determine how many hops and number of neighbors per hop.
+ Defaults to None.
+
+ Returns:
+ Union[SamplerOutput, HeteroSamplerOutput]: NeighborSamplerOutput
+ for the input.
+ """
+ if not self._sampler_is_initialized:
+ self._init_sampler()
+ if num_neighbors is not None:
+ self.num_neighbors = num_neighbors
+
+ # FIXME: Right now, only input nodes/edges as tensors are supported
+ if not isinstance(seed_nodes, Tensor):
+ raise NotImplementedError
+ if not isinstance(seed_edges, Tensor):
+ raise NotImplementedError
+ device = seed_nodes.device
+
+ # TODO: Call sample_from_edges for seed_edges
+ # Turning them into nodes for now.
+ seed_edges = self.edge_index.to(device).T[seed_edges.to(
+ device)].reshape(-1)
+ seed_nodes = torch.cat((seed_nodes, seed_edges), dim=0)
+
+ seed_nodes = seed_nodes.unique().contiguous()
+ node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes)
+ out = self.sampler.sample_from_nodes(node_sample_input)
+
+ return out
diff --git a/examples/llm/multihop_rag/README.md b/examples/llm/multihop_rag/README.md
new file mode 100644
index 000000000000..ff43b16a2c05
--- /dev/null
+++ b/examples/llm/multihop_rag/README.md
@@ -0,0 +1,9 @@
+# Examples for LLM and GNN co-training
+
+| Example | Description |
+| -------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
+| [`multihop_download.sh`](./multihop_download.sh) | Downloads all the components of the multihop dataset. |
+| [`multihop_preprocess.py`](./multihop_preprocess.py) | Preprocesses the dataset to pair questions/answers with components in the knowledge graph. Contains documentation to describe the process. |
+| [`rag_generate_multihop.py`](./rag_generate_multihop.py) | Utilizes the sample remote backend in [`g_retriever_utils`](../g_retriever_utils/) to generate subgraphs for the multihop dataset. |
+
+NOTE: Performance of GRetriever on this dataset has not been evaluated.
diff --git a/examples/llm/multihop_rag/multihop_download.sh b/examples/llm/multihop_rag/multihop_download.sh
new file mode 100644
index 000000000000..3c1970d39440
--- /dev/null
+++ b/examples/llm/multihop_rag/multihop_download.sh
@@ -0,0 +1,12 @@
+#!/bin/sh
+
+# Wikidata5m
+
+wget -O "wikidata5m_alias.tar.gz" "https://www.dropbox.com/s/lnbhc8yuhit4wm5/wikidata5m_alias.tar.gz"
+tar -xvf "wikidata5m_alias.tar.gz"
+wget -O "wikidata5m_all_triplet.txt.gz" "https://www.dropbox.com/s/563omb11cxaqr83/wikidata5m_all_triplet.txt.gz"
+gzip -d "wikidata5m_all_triplet.txt.gz" -f
+
+# 2WikiMultiHopQA
+wget -O "data_ids_april7.zip" "https://www.dropbox.com/s/ms2m13252h6xubs/data_ids_april7.zip"
+unzip -o "data_ids_april7.zip"
diff --git a/examples/llm/multihop_rag/multihop_preprocess.py b/examples/llm/multihop_rag/multihop_preprocess.py
new file mode 100644
index 000000000000..46052bdf1b15
--- /dev/null
+++ b/examples/llm/multihop_rag/multihop_preprocess.py
@@ -0,0 +1,276 @@
+"""Example workflow for downloading and assembling a multihop QA dataset."""
+
+import argparse
+import json
+from subprocess import call
+
+import pandas as pd
+import torch
+import tqdm
+
+from torch_geometric.data import LargeGraphIndexer
+
+# %% [markdown]
+# # Encoding A Large Knowledge Graph Part 2
+
+# %% [markdown]
+# In this notebook, we will continue where we left off by building a new
+# multi-hop QA dataset based on Wikidata.
+
+# %% [markdown]
+# ## Example 2: Building a new Dataset from Questions and an already-existing
+# Knowledge Graph
+
+# %% [markdown]
+# ### Motivation
+
+# %% [markdown]
+# One potential application of knowledge graph structural encodings is
+# capturing the relationships between different entities that are multiple
+# hops apart. This can be challenging for an LLM to recognize from prepended
+# graph information. Here's a motivating example (credit to @Rishi Puri):
+
+# %% [markdown]
+# In this example, the question can only be answered by reasoning about the
+# relationships between the entities in the knowledge graph.
+
+# %% [markdown]
+# ### Building a Multi-Hop QA Dataset
+
+# %% [markdown]
+# To start, we need to download the raw data of a knowledge graph.
+# In this case, we use WikiData5M
+# ([Wang et al]
+# (https://paperswithcode.com/paper/kepler-a-unified-model-for-knowledge)).
+# Here we download the raw triplets and their entity codes. Information about
+# this dataset can be found
+# [here](https://deepgraphlearning.github.io/project/wikidata5m).
+
+# %% [markdown]
+# The following download contains the ID to plaintext mapping for all the
+# entities and relations in the knowledge graph:
+
+rv = call("./multihop_download.sh")
+
+# %% [markdown]
+# To start, we are going to preprocess the knowledge graph to substitute each
+# of the entity/relation codes with their plaintext aliases. This makes it
+# easier to use a pre-trained textual encoding model to create triplet
+# embeddings, as such a model likely won't understand how to properly embed
+# the entity codes.
+
+# %%
+
+# %%
+parser = argparse.ArgumentParser(description="Preprocess wikidata5m")
+parser.add_argument("--n_triplets", type=int, default=-1)
+args = parser.parse_args()
+
+# %%
+# Substitute entity codes with their aliases
+# Picking the first alias for each entity (rather arbitrarily)
+alias_map = {}
+rel_alias_map = {}
+for line in open('wikidata5m_entity.txt'):
+ parts = line.strip().split('\t')
+ entity_id = parts[0]
+ aliases = parts[1:]
+ alias_map[entity_id] = aliases[0]
+for line in open('wikidata5m_relation.txt'):
+ parts = line.strip().split('\t')
+ relation_id = parts[0]
+ relation_name = parts[1]
+ rel_alias_map[relation_id] = relation_name
+
+# %%
+full_graph = []
+missing_total = 0
+total = 0
+limit = None if args.n_triplets == -1 else args.n_triplets
+i = 0
+
+for line in tqdm.tqdm(open('wikidata5m_all_triplet.txt')):
+ if limit is not None and i >= limit:
+ break
+ src, rel, dst = line.strip().split('\t')
+ if src not in alias_map:
+ missing_total += 1
+ if dst not in alias_map:
+ missing_total += 1
+ if rel not in rel_alias_map:
+ missing_total += 1
+ total += 3
+ full_graph.append([
+ alias_map.get(src, src),
+ rel_alias_map.get(rel, rel),
+ alias_map.get(dst, dst)
+ ])
+ i += 1
+print(f"Missing aliases: {missing_total}/{total}")
+
+# %% [markdown]
+# Now `full_graph` represents the knowledge graph triplets in
+# understandable plaintext.
+
+# %% [markdown]
+# Next, we need a set of multi-hop questions that the Knowledge Graph will
+# provide us with context for. We utilize a subset of
+# [HotPotQA](https://hotpotqa.github.io/)
+# ([Yang et al.](https://arxiv.org/pdf/1809.09600)) called
+# [2WikiMultiHopQA](https://github.com/Alab-NII/2wikimultihop)
+# ([Ho et al.](https://aclanthology.org/2020.coling-main.580.pdf)),
+# which includes a subgraph of entities that serve as the ground truth
+# justification for answering each multi-hop question:
+
+# %%
+with open('train.json') as f:
+ train_data = json.load(f)
+train_df = pd.DataFrame(train_data)
+train_df['split_type'] = 'train'
+
+with open('dev.json') as f:
+ dev_data = json.load(f)
+dev_df = pd.DataFrame(dev_data)
+dev_df['split_type'] = 'dev'
+
+with open('test.json') as f:
+ test_data = json.load(f)
+test_df = pd.DataFrame(test_data)
+test_df['split_type'] = 'test'
+
+df = pd.concat([train_df, dev_df, test_df])
+
+# %% [markdown]
+# Now we need to extract the subgraphs
+
+# %%
+df['graph_size'] = df['evidences_id'].apply(lambda row: len(row))
+
+# %% [markdown]
+# (Optional) We take only questions where the evidence graph is non-empty.
+# (Note: this gets rid of the test set):
+
+# %%
+# df = df[df['graph_size'] > 0]
+
+# %%
+refined_df = df[[
+ '_id', 'question', 'answer', 'split_type', 'evidences_id', 'type',
+ 'graph_size'
+]]
+
+# %% [markdown]
+# Checkpoint:
+
+# %%
+refined_df.to_csv('wikimultihopqa_refined.csv', index=False)
+
+# %% [markdown]
+# Now we need to check that all the entities mentioned in the question/answer
+# set are also present in the Wikidata graph:
+
+# %%
+relation_map = {}
+with open('wikidata5m_relation.txt') as f:
+ for line in tqdm.tqdm(f):
+ parts = line.strip().split('\t')
+ for i in range(1, len(parts)):
+ if parts[i] not in relation_map:
+ relation_map[parts[i]] = []
+ relation_map[parts[i]].append(parts[0])
+
+# %%
+entity_set = set()
+with open('wikidata5m_entity.txt') as f:
+ for line in tqdm.tqdm(f):
+ entity_set.add(line.strip().split('\t')[0])
+
+# %%
+missing_entities = set()
+missing_entity_idx = set()
+for i, row in enumerate(refined_df.itertuples()):
+ for trip in row.evidences_id:
+ entities = trip[0], trip[2]
+ for entity in entities:
+ if entity not in entity_set:
+ # print(
+ # f'The following entity was not found in the KG: {entity}'
+ # )
+ missing_entities.add(entity)
+ missing_entity_idx.add(i)
+
+# %% [markdown]
+# Right now, we drop the missing entity entries. Additional preprocessing can
+# be done here to resolve the entity/relation collisions, but that is outside
+# the scope of this notebook.
+
+# %%
+# missing relations are ok, but missing entities cannot be mapped to
+# plaintext, so they should be dropped.
+refined_df.reset_index(inplace=True, drop=True)
+
+# %%
+cleaned_df = refined_df.drop(missing_entity_idx)
+
+# %% [markdown]
+# Now we save the resulting graph and questions/answers dataset:
+
+# %%
+cleaned_df.to_csv('wikimultihopqa_cleaned.csv', index=False)
+
+# %%
+
+# %%
+torch.save(full_graph, 'wikimultihopqa_full_graph.pt')
+
+# %% [markdown]
+# ### Question: How do we extract a contextual subgraph for a given query?
+
+# %% [markdown]
+# The choice of retrieval algorithm is a critical factor in RAG performance.
+# In the next section (1), we will demonstrate a naive method of retrieval
+# for a large knowledge graph, and how to apply it to this dataset along
+# with WebQSP.
+
+# %% [markdown]
+# ### Preparing a Textualized Graph for LLM
+
+# %% [markdown]
+# For now, however, we need to prepare the graph data to be used as a
+# plaintext prefix to the LLM. We want to prompt the LLM with the unique
+# nodes and unique edge triplets of a given subgraph, so we first build a
+# uniquely indexed node DataFrame and edge DataFrame for the knowledge
+# graph. The LargeGraphIndexer makes this straightforward:
+
+# %%
+
+# %%
+indexer = LargeGraphIndexer.from_triplets(full_graph)
+
+# %%
+# Node DF
+textual_nodes = pd.DataFrame.from_dict(
+ {"node_attr": indexer.get_node_features()})
+textual_nodes["node_id"] = textual_nodes.index
+textual_nodes = textual_nodes[["node_id", "node_attr"]]
+
+# %% [markdown]
+# Notice how LargeGraphIndexer ensures that there are no duplicate indices:
+
+# %%
+# Edge DF
+textual_edges = pd.DataFrame(indexer.get_edge_features(),
+ columns=["src", "edge_attr", "dst"])
+textual_edges["src"] = [indexer._nodes[h] for h in textual_edges["src"]]
+textual_edges["dst"] = [indexer._nodes[h] for h in textual_edges["dst"]]
+
+# %% [markdown]
+# Note: The edge table refers to each node by its index in the node table.
+# We will see how this gets utilized later when indexing a subgraph.
+
+# %% [markdown]
+# Now we can save the result
+
+# %%
+textual_nodes.to_csv('wikimultihopqa_textual_nodes.csv', index=False)
+textual_edges.to_csv('wikimultihopqa_textual_edges.csv', index=False)
diff --git a/examples/llm/multihop_rag/rag_generate_multihop.py b/examples/llm/multihop_rag/rag_generate_multihop.py
new file mode 100644
index 000000000000..de93a9e75dd1
--- /dev/null
+++ b/examples/llm/multihop_rag/rag_generate_multihop.py
@@ -0,0 +1,88 @@
+# %%
+import argparse
+import sys
+from typing import Tuple
+
+import pandas as pd
+import torch
+import tqdm
+
+from torch_geometric.data import Data
+from torch_geometric.datasets.web_qsp_dataset import (
+ preprocess_triplet,
+ retrieval_via_pcst,
+)
+from torch_geometric.loader import RAGQueryLoader
+from torch_geometric.nn.nlp import SentenceTransformer
+
+sys.path.append('..')
+
+from g_retriever_utils.rag_backend_utils import \
+ create_remote_backend_from_triplets # noqa: E402
+from g_retriever_utils.rag_feature_store import \
+ SentenceTransformerApproxFeatureStore # noqa: E402
+from g_retriever_utils.rag_graph_store import \
+ NeighborSamplingRAGGraphStore # noqa: E402
+
+# %%
+parser = argparse.ArgumentParser(
+ description="Generate new multihop dataset for rag")
+# TODO: Add more arguments for configuring rag params
+parser.add_argument("--num_samples", type=int)
+args = parser.parse_args()
+
+# %%
+triplets = torch.load('wikimultihopqa_full_graph.pt')
+
+# %%
+df = pd.read_csv('wikimultihopqa_cleaned.csv')
+questions = df['question'][:args.num_samples]
+labels = df['answer'][:args.num_samples]
+
+# %%
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = SentenceTransformer(
+ model_name='sentence-transformers/all-roberta-large-v1').to(device)
+
+# %%
+fs, gs = create_remote_backend_from_triplets(
+ triplets=triplets, node_embedding_model=model,
+ node_method_to_call="encode", path="backend",
+ pre_transform=preprocess_triplet, node_method_kwargs={
+ "batch_size": 256
+ }, graph_db=NeighborSamplingRAGGraphStore,
+ feature_db=SentenceTransformerApproxFeatureStore).load()
+
+# %%
+
+all_textual_nodes = pd.read_csv('wikimultihopqa_textual_nodes.csv')
+all_textual_edges = pd.read_csv('wikimultihopqa_textual_edges.csv')
+
+
+def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3,
+ topk_e: int = 3,
+ cost_e: float = 0.5) -> Data:
+ q_emb = model.encode(query)
+ textual_nodes = all_textual_nodes.iloc[graph["node_idx"]].reset_index()
+ textual_edges = all_textual_edges.iloc[graph["edge_idx"]].reset_index()
+ out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes,
+ textual_edges, topk, topk_e, cost_e)
+ out_graph["desc"] = desc
+ return out_graph
+
+
+# %%
+query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 10},
+ seed_edges_kwargs={"k_edges": 10},
+ sampler_kwargs={"num_neighbors": [40] * 3},
+ local_filter=apply_retrieval_via_pcst)
+
+# %%
+subgs = []
+for q, l in tqdm.tqdm(zip(questions, labels)):
+ subg = query_loader.query(q)
+ subg['question'] = q
+ subg['label'] = l
+ subgs.append(subg)
+
+torch.save(subgs, 'subg_results.pt')
diff --git a/torch_geometric/loader/__init__.py b/torch_geometric/loader/__init__.py
index 266f498a113b..7e83c35befb6 100644
--- a/torch_geometric/loader/__init__.py
+++ b/torch_geometric/loader/__init__.py
@@ -22,6 +22,7 @@
from .prefetch import PrefetchLoader
from .cache import CachedLoader
from .mixin import AffinityMixin
+from .rag_loader import RAGQueryLoader
__all__ = classes = [
'DataLoader',
@@ -50,6 +51,7 @@
'PrefetchLoader',
'CachedLoader',
'AffinityMixin',
+ 'RAGQueryLoader',
]
RandomNodeSampler = deprecated(
diff --git a/torch_geometric/loader/rag_loader.py b/torch_geometric/loader/rag_loader.py
new file mode 100644
index 000000000000..33d6cf0e868e
--- /dev/null
+++ b/torch_geometric/loader/rag_loader.py
@@ -0,0 +1,106 @@
+from abc import abstractmethod
+from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
+
+from torch_geometric.data import Data, FeatureStore, HeteroData
+from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
+from torch_geometric.typing import InputEdges, InputNodes
+
+
+class RAGFeatureStore(Protocol):
+ """Feature store for remote GNN RAG backend."""
+ @abstractmethod
+ def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
+ """Compares the query against all the nodes to find the closest
+ nodes. Returns the indices of the nodes that are to be seeds
+ for the RAG Sampler.
+ """
+ ...
+
+ @abstractmethod
+ def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
+ """Compares the query against all the edges to find the closest
+ edges. Returns the edge indices that are to be the seeds
+ for the RAG Sampler.
+ """
+ ...
+
+ @abstractmethod
+ def load_subgraph(
+ self, sample: Union[SamplerOutput, HeteroSamplerOutput]
+ ) -> Union[Data, HeteroData]:
+ """Combines sampled subgraph output with features in a Data object."""
+ ...
+
+
+class RAGGraphStore(Protocol):
+    """Graph store for remote GNN RAG backend."""
+    @abstractmethod
+    def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
+                        **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
+        """Samples a subgraph using the seed nodes and edges."""
+        ...
+
+    @abstractmethod
+    def register_feature_store(self, feature_store: FeatureStore):
+        """Register a feature store to be used with the sampler. Samplers need
+        info from the feature store in order to work properly on HeteroGraphs.
+        """
+        ...
+
+
+# TODO: Make compatible with Heterographs
+
+
+class RAGQueryLoader:
+    def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
+                 local_filter: Optional[Callable[[Data, Any], Data]] = None,
+                 seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
+                 seed_edges_kwargs: Optional[Dict[str, Any]] = None,
+                 sampler_kwargs: Optional[Dict[str, Any]] = None,
+                 loader_kwargs: Optional[Dict[str, Any]] = None):
+        """Loader meant for making queries from a remote backend.
+
+        Args:
+            data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore
+                and GraphStore to load from. Assumed to conform to the
+                protocols listed above.
+            local_filter (Optional[Callable[[Data, Any], Data]], optional):
+                Optional local transform to apply to data after retrieval.
+                Defaults to None.
+            seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Parameters
+                to pass into process for fetching seed nodes. Defaults to None.
+            seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters
+                to pass into process for fetching seed edges. Defaults to None.
+            sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to
+                pass into process for sampling graph. Defaults to None.
+            loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to
+                pass into process for loading graph features. Defaults to None.
+        """
+        fstore, gstore = data
+        self.feature_store = fstore
+        self.graph_store = gstore
+        self.graph_store.register_feature_store(self.feature_store)
+        self.local_filter = local_filter
+        self.seed_nodes_kwargs = seed_nodes_kwargs or {}
+        self.seed_edges_kwargs = seed_edges_kwargs or {}
+        self.sampler_kwargs = sampler_kwargs or {}
+        self.loader_kwargs = loader_kwargs or {}
+
+    def query(self, query: Any) -> Data:
+        """Retrieve a subgraph associated with the query with all its feature
+        attributes.
+        """
+        seed_nodes = self.feature_store.retrieve_seed_nodes(
+            query, **self.seed_nodes_kwargs)
+        seed_edges = self.feature_store.retrieve_seed_edges(
+            query, **self.seed_edges_kwargs)
+
+        subgraph_sample = self.graph_store.sample_subgraph(
+            seed_nodes, seed_edges, **self.sampler_kwargs)
+
+        data = self.feature_store.load_subgraph(sample=subgraph_sample,
+                                                **self.loader_kwargs)
+
+        if self.local_filter:
+            data = self.local_filter(data, query)
+        return data
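
Note on the protocol contract: the sketch below is a rough, standard-library-only illustration of the call order that `RAGQueryLoader.query` follows (seed retrieval, subgraph sampling, feature loading, optional local filter). It is not part of this patch and does not use PyG. `ToyFeatureStore`, `ToyGraphStore`, `run_query`, the word-overlap scoring, and the `num_hops` parameter are all hypothetical stand-ins chosen so the example runs without a remote backend.

```python
from typing import Callable, Dict, List, Optional, Tuple

Edge = Tuple[int, int]


class ToyFeatureStore:
    """Hypothetical in-memory feature store; mirrors the RAGFeatureStore
    method names but is NOT a PyG class."""
    def __init__(self, node_text: Dict[int, str], edges: List[Edge]):
        self.node_text = node_text
        self.edges = edges

    def _overlap(self, query: str, text: str) -> int:
        # Crude similarity: number of shared lowercase words.
        return len(set(query.lower().split()) & set(text.lower().split()))

    def retrieve_seed_nodes(self, query: str, k_nodes: int = 1) -> List[int]:
        ranked = sorted(self.node_text,
                        key=lambda n: -self._overlap(query, self.node_text[n]))
        return ranked[:k_nodes]

    def retrieve_seed_edges(self, query: str, k_edges: int = 1) -> List[Edge]:
        # Score an edge by the summed overlap of its two endpoints.
        def score(edge: Edge) -> int:
            return sum(self._overlap(query, self.node_text[n]) for n in edge)
        return sorted(self.edges, key=lambda e: -score(e))[:k_edges]

    def load_subgraph(self, sample, **kwargs) -> dict:
        # Attach features to the sampled node ids.
        nodes, edges = sample
        return {"x": {n: self.node_text[n] for n in nodes}, "edges": edges}


class ToyGraphStore:
    """Hypothetical in-memory graph store; mirrors RAGGraphStore."""
    def __init__(self, edges: List[Edge]):
        self.edges = edges
        self.feature_store = None

    def register_feature_store(self, feature_store) -> None:
        self.feature_store = feature_store

    def sample_subgraph(self, seed_nodes, seed_edges, num_hops: int = 1):
        # Start from the seeds, then expand along edges `num_hops` times.
        nodes = set(seed_nodes)
        for u, v in seed_edges:
            nodes.update((u, v))
        for _ in range(num_hops):
            frontier = set()
            for u, v in self.edges:
                if u in nodes or v in nodes:
                    frontier.update((u, v))
            nodes |= frontier
        kept = [e for e in self.edges if e[0] in nodes and e[1] in nodes]
        return sorted(nodes), kept


def run_query(fs, gs, query: str,
              local_filter: Optional[Callable] = None,
              seed_nodes_kwargs: Optional[dict] = None,
              seed_edges_kwargs: Optional[dict] = None,
              sampler_kwargs: Optional[dict] = None):
    """Stand-in for the loader: same call order as RAGQueryLoader.query."""
    gs.register_feature_store(fs)
    seed_nodes = fs.retrieve_seed_nodes(query, **(seed_nodes_kwargs or {}))
    seed_edges = fs.retrieve_seed_edges(query, **(seed_edges_kwargs or {}))
    sample = gs.sample_subgraph(seed_nodes, seed_edges,
                                **(sampler_kwargs or {}))
    data = fs.load_subgraph(sample)
    if local_filter:
        data = local_filter(data, query)
    return data


node_text = {
    0: "Paris is the capital of France",
    1: "France is in Europe",
    2: "Tokyo is the capital of Japan",
    3: "Japan is in Asia",
}
edges = [(0, 1), (2, 3)]
fs, gs = ToyFeatureStore(node_text, edges), ToyGraphStore(edges)
result = run_query(fs, gs, "capital of Japan",
                   seed_nodes_kwargs={"k_nodes": 1},
                   seed_edges_kwargs={"k_edges": 1},
                   sampler_kwargs={"num_hops": 1})
print(result)
```

With a real backend, the two toy classes would be replaced by stores implementing `RAGFeatureStore` and `RAGGraphStore`, and `run_query` by `RAGQueryLoader(data=(fs, gs), ...).query(q)`.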