diff --git a/CHANGELOG.md b/CHANGELOG.md
index 86e06d8835e8..341be665fabf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
+- Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
+- Added `loader.RagQueryLoader` with Remote Backend Example ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
+- Added `data.LargeGraphIndexer` ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
 - Added `GIT-Mol` ([#9730](https://github.com/pyg-team/pytorch_geometric/pull/9730))
 - Added comment in `g_retriever.py` pointing to `Neo4j` Graph DB integration demo ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9797))
 - Added `MoleculeGPT` example ([#9710](https://github.com/pyg-team/pytorch_geometric/pull/9710))
diff --git a/examples/llm/README.md b/examples/llm/README.md
index eb471563de8e..4503e28ce6ee 100644
--- a/examples/llm/README.md
+++ b/examples/llm/README.md
@@ -1,8 +1,11 @@
 # Examples for Co-training LLMs and GNNs
 
-| Example                                 | Description                                                                                                                                                                                        |
-| --------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| [`g_retriever.py`](./g_retriever.py)    | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information                                        |
-| [`git_mol.py`](./git_mol.py)            | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text                                                                                           |
-| [`molecule_gpt.py`](./molecule_gpt.py)  | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction                                                                                              |
-| [`glem.py`](./glem.py)                  | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results      |
+| Example                                       | Description                                                                                                                                                                                      |
+| --------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| [`g_retriever.py`](./g_retriever.py)          | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information                                      |
+| [`g_retriever_utils/`](./g_retriever_utils/)  | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods.                                                                                  |
+| [`multihop_rag/`](./multihop_rag/)            | Contains starter code and an example run for building a Multi-hop dataset using Wikidata5M and 2WikiMultiHopQA                                                                                    |
+| [`nvtx_examples/`](./nvtx_examples/)          | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis.                                                                                                     |
+| [`molecule_gpt.py`](./molecule_gpt.py)       | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction                                                                                            |
+| [`glem.py`](./glem.py)                       | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results    |
+| [`git_mol.py`](./git_mol.py)                 | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text                                                                                         |
diff --git a/examples/llm/g_retriever.py b/examples/llm/g_retriever.py
index 984ce3f010e7..a48901f1ff0e 100644
--- a/examples/llm/g_retriever.py
+++ b/examples/llm/g_retriever.py
@@ -11,6 +11,7 @@
 https://github.com/neo4j-product-examples/neo4j-gnn-llm-example
 """
 import argparse
+import gc
 import math
 import os.path as osp
 import re
@@ -145,6 +146,9 @@ def adjust_learning_rate(param_group, LR, epoch):
     test_loader = DataLoader(test_dataset, batch_size=eval_batch_size,
                              drop_last=False, pin_memory=True, shuffle=False)
 
+    # Free up memory after data preprocessing
+    gc.collect()
+    torch.cuda.empty_cache()
     gnn = GAT(
         in_channels=1024,
         hidden_channels=hidden_channels,
diff --git a/examples/llm/g_retriever_utils/README.md b/examples/llm/g_retriever_utils/README.md
new file mode 100644
index 000000000000..e072e6746b7c
--- /dev/null
+++ b/examples/llm/g_retriever_utils/README.md
@@ -0,0 +1,11 @@
+# Examples for LLM and GNN co-training
+
+| Example                                                           | Description                                                                                                                                                               |
+| ----------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [`rag_feature_store.py`](./rag_feature_store.py)                  | A Proof of Concept Implementation of a RAG enabled FeatureStore that can serve as a starting point for implementing a custom RAG Remote Backend                          |
+| [`rag_graph_store.py`](./rag_graph_store.py)                      | A Proof of Concept Implementation of a RAG enabled GraphStore that can serve as a starting point for implementing a custom RAG Remote Backend                            |
+| [`rag_backend_utils.py`](./rag_backend_utils.py)                  | Utility functions used for loading a series of Knowledge Graph Triplets into the Remote Backend defined by a FeatureStore and GraphStore                                 |
+| [`rag_generate.py`](./rag_generate.py)                            | Script for generating a unique set of subgraphs from the WebQSP dataset using a custom defined retrieval algorithm (defaults to the FeatureStore and GraphStore provided) |
+| [`benchmark_model_archs_rag.py`](./benchmark_model_archs_rag.py)  | Script for running a GNN/LLM benchmark on GRetriever while grid searching relevant architecture parameters and datasets.                                                 |
+
+NOTE: Evaluating performance on GRetriever with smaller sample sizes may result in subpar performance. It is not unusual for the fine-tuned model/LLM to perform worse than an untrained LLM on very small sample sizes.
diff --git a/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py b/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py
new file mode 100644
index 000000000000..6522aafca68b
--- /dev/null
+++ b/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py
@@ -0,0 +1,105 @@
+"""Used to benchmark the performance of an untuned/fine-tuned LLM against
+GRetriever with various architectures and layer depths.
+""" +# %% +import argparse +import sys + +import torch + +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.nn.models import GAT, MLP, GRetriever + +sys.path.append('..') +from minimal_demo import ( # noqa: E402 # isort:skip + benchmark_models, get_loss, inference_step, +) + +# %% +parser = argparse.ArgumentParser( + description="""Benchmarker for GRetriever\n""" + + """NOTE: Evaluating with smaller samples may result in poorer""" + + """ performance for the trained models compared to """ + + """untrained models.""") +parser.add_argument("--hidden_channels", type=int, default=1024) +parser.add_argument("--learning_rate", type=float, default=1e-5) +parser.add_argument("--epochs", type=int, default=2) +parser.add_argument("--batch_size", type=int, default=8) +parser.add_argument("--eval_batch_size", type=int, default=16) +parser.add_argument("--tiny_llama", action='store_true') + +parser.add_argument("--dataset_path", type=str, required=False) +# Default to WebQSP split +parser.add_argument("--num_train", type=int, default=2826) +parser.add_argument("--num_val", type=int, default=246) +parser.add_argument("--num_test", type=int, default=1628) + +args = parser.parse_args() + +# %% +hidden_channels = args.hidden_channels +lr = args.learning_rate +epochs = args.epochs +batch_size = args.batch_size +eval_batch_size = args.eval_batch_size + +# %% +if not args.dataset_path: + ds = WebQSPDataset('benchmark_archs', verbose=True, force_reload=True) +else: + # We just assume that the size of the dataset accomodates the + # train/val/test split, because checking may be expensive. + dataset = torch.load(args.dataset_path) + + class MockDataset: + """Utility class to patch the fields in WebQSPDataset used by + GRetriever. + """ + def __init__(self) -> None: + pass + + @property + def split_idxs(self) -> dict: + # Imitates the WebQSP split method + return { + "train": + torch.arange(args.num_train), + "val": + torch.arange(args.num_val) + args.num_train, + "test": + torch.arange(args.num_test) + args.num_train + args.num_val, + } + + def __getitem__(self, idx: int): + return dataset[idx] + + ds = MockDataset() + +# %% +model_names = [] +model_classes = [] +model_kwargs = [] +model_type = ["GAT", "MLP"] +models = {"GAT": GAT, "MLP": MLP} +# Use to vary the depth of the GNN model +num_layers = [4] +# Use to vary the number of LLM tokens reserved for GNN output +num_tokens = [1] +for m_type in model_type: + for n_layer in num_layers: + for n_tokens in num_tokens: + model_names.append(f"{m_type}_{n_layer}_{n_tokens}") + model_classes.append(GRetriever) + kwargs = dict(gnn_hidden_channels=hidden_channels, + num_gnn_layers=n_layer, gnn_to_use=models[m_type], + mlp_out_tokens=n_tokens) + if args.tiny_llama: + kwargs['llm_to_use'] = 'TinyLlama/TinyLlama-1.1B-Chat-v0.1' + kwargs['mlp_out_dim'] = 2048 + kwargs['num_llm_params'] = 1 + model_kwargs.append(kwargs) + +# %% +benchmark_models(model_classes, model_names, model_kwargs, ds, lr, epochs, + batch_size, eval_batch_size, get_loss, inference_step, + skip_LLMs=False, tiny_llama=args.tiny_llama, force=True) diff --git a/examples/llm/g_retriever_utils/minimal_demo.py b/examples/llm/g_retriever_utils/minimal_demo.py new file mode 100644 index 000000000000..bdd78c3180cb --- /dev/null +++ b/examples/llm/g_retriever_utils/minimal_demo.py @@ -0,0 +1,638 @@ +"""This example implements the G-Retriever model +(https://arxiv.org/abs/2402.07630) using PyG. 
+
+G-Retriever significantly reduces hallucinations by 54% compared to the
+stand-alone LLM baseline.
+
+Requirements:
+`pip install datasets transformers pcst_fast sentencepiece accelerate`
+"""
+import argparse
+import gc
+import math
+import multiprocessing as mp
+import re
+import sys
+import time
+from os import path
+from typing import Any, Callable, Dict, List, Type
+
+import pandas as pd
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn.utils import clip_grad_norm_
+from tqdm import tqdm
+
+from torch_geometric import seed_everything
+from torch_geometric.data import Dataset
+from torch_geometric.datasets import WebQSPDataset
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn.models import GAT, GRetriever
+from torch_geometric.nn.nlp import LLM
+
+# NOTE: This used to be merged in the G-Retriever example.
+# FIXME: Getting the demos working like before is a WIP
+sys.path.append('..')
+from g_retriever import (  # noqa: E402 # isort:skip
+    compute_metrics, load_params_dict, save_params_dict,
+)
+
+
+def _detect_hallucinate(inp):
+    pred, label = inp
+    try:
+        split_pred = pred.split('[/s]')[0].strip().split('|')
+        correct_hit = len(re.findall(split_pred[0], label)) > 0
+        correct_hit = correct_hit or any(
+            [label_i in pred.lower() for label_i in label.split('|')])
+        hallucination = not correct_hit
+        return hallucination
+    except:  # noqa
+        return "skip"
+
+
+def detect_hallucinate(pred_batch, label_batch):
+    r"""An approximation for the unsolved task of detecting hallucinations.
+    We define a hallucination as an output that contains no instances of an
+    acceptable label.
+    """
+    with mp.Pool(len(pred_batch)) as p:
+        res = p.map(_detect_hallucinate, zip(pred_batch, label_batch))
+    return res
+
+
+def compute_n_parameters(model: torch.nn.Module) -> int:
+    return sum([p.numel() for p in model.parameters() if p.requires_grad])
+
+
+def get_loss(model, batch, model_save_name) -> Tensor:
+    if model_save_name == 'llm':
+        return model(batch.question, batch.label, batch.desc)
+    else:
+        return model(batch.question, batch.x, batch.edge_index, batch.batch,
+                     batch.label, batch.edge_attr, batch.desc)
+
+
+def inference_step(model, batch, model_save_name):
+    if model_save_name == 'llm':
+        return model.inference(batch.question, batch.desc)
+    else:
+        return model.inference(batch.question, batch.x, batch.edge_index,
+                               batch.batch, batch.edge_attr, batch.desc)
+
+
+# TODO: Merge with G-Retriever example and make sure changes still work
+def train(
+    num_epochs,
+    hidden_channels,
+    num_gnn_layers,
+    batch_size,
+    eval_batch_size,
+    lr,
+    checkpointing=False,
+    tiny_llama=False,
+    model=None,
+    dataset=None,
+    model_save_name=None,
+):
+    def adjust_learning_rate(param_group, LR, epoch):
+        # Decay the learning rate with half-cycle cosine after warmup
+        min_lr = 5e-6
+        warmup_epochs = 1
+        if epoch < warmup_epochs:
+            lr = LR
+        else:
+            lr = min_lr + (LR - min_lr) * 0.5 * (
+                1.0 + math.cos(math.pi * (epoch - warmup_epochs) /
+                               (num_epochs - warmup_epochs)))
+        param_group['lr'] = lr
+        return lr
+
+    start_time = time.time()
+    seed_everything(42)
+    if dataset is None:
+        dataset = WebQSPDataset()
+        gc.collect()
+    elif not isinstance(dataset, Dataset) and callable(dataset):
+        dataset = dataset()
+        gc.collect()
+    idx_split = dataset.split_idxs
+
+    # Step 1: Build the train/val/test datasets
+    train_dataset = [dataset[i] for i in idx_split['train']]
+    val_dataset = [dataset[i] for i in idx_split['val']]
+    test_dataset = [dataset[i] for i in idx_split['test']]
+
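+    # NOTE: pin_memory=True below keeps each batch staged in page-locked
+    # host memory, which speeds up CPU-to-GPU transfer when training on CUDA.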
+    train_loader = DataLoader(train_dataset, batch_size=batch_size,
+                              drop_last=True, pin_memory=True, shuffle=True)
+    val_loader = DataLoader(val_dataset, batch_size=eval_batch_size,
+                            drop_last=False, pin_memory=True, shuffle=False)
+    test_loader = DataLoader(test_dataset, batch_size=eval_batch_size,
+                             drop_last=False, pin_memory=True, shuffle=False)
+
+    if model is None:
+        gc.collect()
+        gnn = GAT(
+            in_channels=1024,
+            hidden_channels=hidden_channels,
+            out_channels=1024,
+            num_layers=num_gnn_layers,
+            heads=4,
+        )
+        if tiny_llama:
+            llm = LLM(
+                model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
+                num_params=1,
+            )
+            model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=2048)
+        else:
+            llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7)
+            model = GRetriever(llm=llm, gnn=gnn)
+
+    if model_save_name is None:
+        model_save_name = 'gnn_llm' if num_gnn_layers != 0 else 'llm'
+        if model_save_name == 'llm':
+            model = llm
+
+    params = [p for _, p in model.named_parameters() if p.requires_grad]
+    optimizer = torch.optim.AdamW([
+        {
+            'params': params,
+            'lr': lr,
+            'weight_decay': 0.05
+        },
+    ], betas=(0.9, 0.95))
+    grad_steps = 2
+
+    best_epoch = 0
+    best_val_loss = float('inf')
+    prep_time = time.time() - start_time
+    for epoch in range(num_epochs):
+        model.train()
+        epoch_loss = 0
+        if epoch == 0:
+            print(f"Total Preparation Time: {prep_time:.2f}s")
+            start_time = time.time()
+            print("Training beginning...")
+        epoch_str = f'Epoch: {epoch + 1}|{num_epochs}'
+        loader = tqdm(train_loader, desc=epoch_str)
+        for step, batch in enumerate(loader):
+            optimizer.zero_grad()
+            loss = get_loss(model, batch, model_save_name)
+            loss.backward()
+
+            clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
+
+            if (step + 1) % grad_steps == 0:
+                adjust_learning_rate(optimizer.param_groups[0], lr,
+                                     step / len(train_loader) + epoch)
+
+            optimizer.step()
+            epoch_loss = epoch_loss + float(loss)
+
+            if (step + 1) % grad_steps == 0:
+                lr = optimizer.param_groups[0]['lr']
+        train_loss = epoch_loss / len(train_loader)
+        print(epoch_str + f', Train Loss: {train_loss:.4f}')
+
+        val_loss = 0
+        eval_output = []
+        model.eval()
+        with torch.no_grad():
+            for step, batch in enumerate(val_loader):
+                loss = get_loss(model, batch, model_save_name)
+                val_loss += loss.item()
+            val_loss = val_loss / len(val_loader)
+            print(epoch_str + f", Val Loss: {val_loss:.4f}")
+        if checkpointing and val_loss < best_val_loss:
+            print("Checkpointing best model...")
+            best_val_loss = val_loss
+            best_epoch = epoch
+            save_params_dict(model, f'{model_save_name}_best_val_loss_ckpt.pt')
+    torch.cuda.empty_cache()
+    torch.cuda.reset_max_memory_allocated()
+
+    if checkpointing and best_epoch != num_epochs - 1:
+        print("Loading best checkpoint...")
+        model = load_params_dict(
+            model,
+            f'{model_save_name}_best_val_loss_ckpt.pt',
+        )
+
+    model.eval()
+    eval_output = []
+    print("Final evaluation...")
+    progress_bar_test = tqdm(range(len(test_loader)))
+    for step, batch in enumerate(test_loader):
+        with torch.no_grad():
+            pred = inference_step(model, batch, model_save_name)
+            eval_data = {
+                'pred': pred,
+                'question': batch.question,
+                'desc': batch.desc,
+                'label': batch.label
+            }
+            eval_output.append(eval_data)
+        progress_bar_test.update(1)
+
+    # Post-processing & compute metrics
+    compute_metrics(eval_output)
+    print(f"Total Training Time: {time.time() - start_time:.2f}s")
+    save_params_dict(model, f'{model_save_name}.pt')
+    torch.save(eval_output, f'{model_save_name}_eval_outs.pt')
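+    # The predictions saved above let hallucination metrics be recomputed
+    # later (e.g. by benchmark_models) without rerunning inference.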
print("Done!") + return prep_time, dataset, eval_output # noqa: F821 + + +def _eval_hallucinations_on_loader(outs, loader, eval_batch_size): + model_save_list = [] + model_preds = [] + for out in outs: + model_preds += out['pred'] + for i, batch in enumerate(loader): + correct_answer = batch.label + + model_pred = model_preds[i * eval_batch_size:(i + 1) * eval_batch_size] + model_hallucinates = detect_hallucinate(model_pred, correct_answer) + model_save_list += [tup for tup in zip(model_pred, model_hallucinates)] + return model_save_list + + +def benchmark_models(models: List[Type[nn.Module]], model_names: List[str], + model_kwargs: List[Dict[str, Any]], dataset: Dataset, + lr: float, epochs: int, batch_size: int, + eval_batch_size: int, loss_fn: Callable, + inference_fn: Callable, skip_LLMs: bool = True, + tiny_llama: bool = False, checkpointing: bool = True, + force: bool = False, root_dir='.'): + """Utility function for creating a model benchmark for GRetriever that + grid searches over hyperparameters. Produces a DataFrame containing + metrics for each model. + + Args: + models (List[Type[nn.Module]]): Models to be benchmarked. + model_names (List[str]): Name of save files for model checkpoints + model_kwargs (List[Dict[str, Any]]): Parameters to use for each + particular model. + dataset (Dataset): Input dataset to train on. + lr (float): Learning rate + epochs (int): Number of epochs + batch_size (int): Batch size for training + eval_batch_size (int): Batch size for eval. Also determines + hallucination detection concurrancy. + loss_fn (Callable): Loss function + inference_fn (Callable): Inference function + skip_LLMs (bool, optional): Whether to skip LLM-only runs. + Defaults to True. + tiny_llama (bool, optional): Whether to use tiny llama as LLM. + Defaults to False. + checkpointing (bool, optional): Whether to checkpoint models. + Defaults to True. + force (bool, optional): Whether to rerun already existing results. + Defaults to False. + root_dir (str, optional): Dir to save results and checkpoints in. + Defaults to '.'. 
+    """
+    model_log: Dict[str, Dict[str, Any]] = dict()
+    idx_split = dataset.split_idxs
+    test_dataset = [dataset[i] for i in idx_split['test']]
+    loader = DataLoader(test_dataset, batch_size=eval_batch_size,
+                        drop_last=False, pin_memory=True, shuffle=False)
+
+    if not skip_LLMs:
+        if tiny_llama:
+            pure_llm = LLM(
+                model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.1",
+                num_params=1,
+            )
+        else:
+            pure_llm = LLM(model_name="meta-llama/Llama-2-7b-chat-hf",
+                           num_params=7)
+
+        if force or not path.exists(root_dir + "/pure_llm_model_log.pt"):
+            model_log["pure_llm"] = dict()
+
+            pure_preds = []
+            for batch in tqdm(loader):
+                pure_llm_preds = pure_llm.inference(batch.question, batch.desc,
+                                                    max_tokens=256)
+                pure_preds += pure_llm_preds
+            pure_preds = [{"pred": pred} for pred in pure_preds]
+
+            model_log["pure_llm"]["preds"] = pure_preds
+            model_log["pure_llm"]["hallucinates_list"] = \
+                _eval_hallucinations_on_loader(pure_preds, loader,
+                                               eval_batch_size)
+            model_log["pure_llm"]["n_params"] = compute_n_parameters(pure_llm)
+            torch.save(model_log["pure_llm"],
+                       root_dir + "/pure_llm_model_log.pt")
+        else:
+            model_log["pure_llm"] = \
+                torch.load(root_dir+"/pure_llm_model_log.pt")
+
+        # LORA
+        if force or not path.exists(root_dir + "/tuned_llm_model_log.pt"):
+            model_log["tuned_llm"] = dict()
+            since = time.time()
+            gc.collect()
+            prep_time, _, lora_eval_outs = train(epochs, None, None,
+                                                 batch_size, eval_batch_size,
+                                                 lr, model=pure_llm,
+                                                 dataset=dataset,
+                                                 model_save_name='llm')
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_allocated()
+            gc.collect()
+            e2e_time = round(time.time() - since, 2)
+            model_log["tuned_llm"]["prep_time"] = prep_time
+            model_log["tuned_llm"]["e2e_time"] = e2e_time
+            model_log["tuned_llm"]["eval_output"] = lora_eval_outs
+            print("E2E time (e2e_time) =", e2e_time, "seconds")
+            print("E2E time minus Prep Time =", e2e_time - prep_time,
+                  "seconds")
+
+            model_log["tuned_llm"]["hallucinates_list"] = \
+                _eval_hallucinations_on_loader(lora_eval_outs, loader,
+                                               eval_batch_size)
+            model_log["tuned_llm"]["n_params"] = compute_n_parameters(pure_llm)
+            torch.save(model_log["tuned_llm"],
+                       root_dir + "/tuned_llm_model_log.pt")
+        else:
+            model_log["tuned_llm"] = \
+                torch.load(root_dir+"/tuned_llm_model_log.pt")
+
+        del pure_llm
+        gc.collect()
+
+    # All other models
+    for name, Model, kwargs in zip(model_names, models, model_kwargs):
+        model_log[name] = dict()
+        train_model = True
+        if path.exists(root_dir + f"/{name}.pt") and not force:
+            print(f"Model {name} appears to already exist.")
+            print("Would you like to retrain?")
+            train_model = str(input("(y/n):")).lower() == "y"
+
+        if train_model:
+            since = time.time()
+            gc.collect()
+            model = Model(**kwargs)
+            prep_time, _, model_eval_outs = train(
+                num_epochs=epochs, hidden_channels=None, num_gnn_layers=None,
+                batch_size=batch_size, eval_batch_size=eval_batch_size, lr=lr,
+                checkpointing=checkpointing, tiny_llama=tiny_llama,
+                dataset=dataset, model_save_name=root_dir + '/' + name,
+                model=model)
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_allocated()
+            gc.collect()
+            e2e_time = round(time.time() - since, 2)
+            model_log[name]["prep_time"] = prep_time
+            model_log[name]["e2e_time"] = e2e_time
+            model_log[name]["eval_output"] = model_eval_outs
+            print("E2E time (e2e_time) =", e2e_time, "seconds")
+            print("E2E time minus Prep Time =", e2e_time - prep_time,
+                  "seconds")
+            model_log[name]["n_params"] = compute_n_parameters(model)
+            del model
+            gc.collect()
+        else:
+            model_eval_outs = torch.load(root_dir + f"/{name}_eval_outs.pt")
+
+        # Calculate Hallucinations
+        skip_hallucination_detection = False
+
+        if path.exists(root_dir + f"/{name}_model_log.pt") and not force:
+            print(f"Saved outputs for {name} have been found.")
+            print("Would you like to redo?")
+            user_input = str(input("(y/n):")).lower()
+            skip_hallucination_detection = user_input != "y"
+
+        if not skip_hallucination_detection:
+            model_save_list = _eval_hallucinations_on_loader(
+                model_eval_outs, loader, eval_batch_size)
+
+            model_log[name]["hallucinates_list"] = model_save_list
+            torch.save(model_log[name], root_dir + f"/{name}_model_log.pt")
+        else:
+            model_log[name]["hallucinates_list"] = \
+                torch.load(
+                    root_dir+f"/{name}_model_log.pt"
+                )["hallucinates_list"]
+
+    hal_dict = {
+        k: [tup[1] for tup in v["hallucinates_list"]]
+        for (k, v) in model_log.items()
+    }
+    hallucinates_df = pd.DataFrame(hal_dict).astype(str)
+    hallucinates_df = hallucinates_df.apply(pd.Series.value_counts).transpose()
+    hallucinates_df['e2e_time'] = pd.Series(
+        {k: v.get('e2e_time')
+         for (k, v) in model_log.items()})
+    hallucinates_df['n_params'] = pd.Series(
+        {k: v.get('n_params')
+         for (k, v) in model_log.items()})
+    print(hallucinates_df)
+    hallucinates_df.to_csv(root_dir + "/hallucinates_df.csv", index=False)
+
+
+def minimal_demo(gnn_llm_eval_outs, dataset, lr, epochs, batch_size,
+                 eval_batch_size, loss_fn, inference_fn,
+                 skip_pretrained_LLM=False, tiny_llama=False):
+    if not skip_pretrained_LLM:
+        print("First comparing against a pretrained LLM...")
+    # Step 1: Define the test loader
+    idx_split = dataset.split_idxs
+    test_dataset = [dataset[i] for i in idx_split['test']]
+    # Use the eval batch size for the demo loader
+    loader = DataLoader(test_dataset, batch_size=eval_batch_size,
+                        drop_last=False, pin_memory=True, shuffle=False)
+    if tiny_llama:
+        pure_llm = LLM(
+            model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.1",
+            num_params=1,
+        )
+    else:
+        pure_llm = LLM(model_name="meta-llama/Llama-2-7b-chat-hf",
+                       num_params=7)
+    if path.exists("demo_save_dict.pt"):
+        print("Saved outputs for the first step of the demo found.")
+        print("Would you like to redo?")
+        user_input = str(input("(y/n):")).lower()
+        skip_step_one = user_input == "n"
+    else:
+        skip_step_one = False
+
+    if not skip_step_one:
+        gnn_llm_hallucin_sum = 0
+        pure_llm_hallucin_sum = 0
+        gnn_save_list = []
+        untuned_llm_save_list = []
+        gnn_llm_preds = []
+        for out in gnn_llm_eval_outs:
+            gnn_llm_preds += out['pred']
+        if skip_pretrained_LLM:
+            print("Checking GNN+LLM for hallucinations...")
+        else:
+            print(
+                "Checking pretrained LLM vs trained GNN+LLM for hallucinations..."  # noqa
+            )
+        for i, batch in enumerate(tqdm(loader)):
+            question = batch.question
+            correct_answer = batch.label
+
+            gnn_llm_pred = gnn_llm_preds[i * eval_batch_size:(i + 1) *
+                                         eval_batch_size]
+            gnn_llm_hallucinates = detect_hallucinate(gnn_llm_pred,
+                                                      correct_answer)
+            gnn_save_list += [
+                tup for tup in zip(gnn_llm_pred, gnn_llm_hallucinates)
+            ]
+
+            if not skip_pretrained_LLM:
+                # GNN+LLM only using 32 tokens to answer.
+                # Allow more output tokens for untrained LLM
+                pure_llm_pred = pure_llm.inference(batch.question, batch.desc,
+                                                   max_tokens=256)
+                pure_llm_hallucinates = detect_hallucinate(
+                    pure_llm_pred, correct_answer)
+            else:
+                pure_llm_pred = [''] * len(gnn_llm_hallucinates)
+                pure_llm_hallucinates = [False] * len(gnn_llm_hallucinates)
+            untuned_llm_save_list += [
+                tup for tup in zip(pure_llm_pred, pure_llm_hallucinates)
+            ]
+
+            for gnn_llm_hal, pure_llm_hal in zip(gnn_llm_hallucinates,
+                                                 pure_llm_hallucinates):
+                if gnn_llm_hal == "skip" or pure_llm_hal == "skip":  # noqa
+                    # skipping when hallucination is hard to eval
+                    continue
+                gnn_llm_hallucin_sum += int(gnn_llm_hal)
+                pure_llm_hallucin_sum += int(pure_llm_hal)
+        if not skip_pretrained_LLM:
+            print("Total Pure LLM Hallucinations:", pure_llm_hallucin_sum)
+            print("Total GNN+LLM Hallucinations:", gnn_llm_hallucin_sum)
+            percent = 100.0 * round(
+                1 - (gnn_llm_hallucin_sum / pure_llm_hallucin_sum), 2)
+            print(f"GNN reduces pretrained LLM hallucinations by: ~{percent}%")
+            print("Note: hallucinations detected by regex hence the ~")
+            print("Now we see how the LLM compares when finetuned...")
+            print("Saving outputs of GNN+LLM and pretrained LLM...")
+        save_dict = {
+            "gnn_save_list": gnn_save_list,
+            "untuned_llm_save_list": untuned_llm_save_list,
+            "gnn_llm_hallucin_sum": gnn_llm_hallucin_sum,
+            "pure_llm_hallucin_sum": pure_llm_hallucin_sum
+        }
+        torch.save(save_dict, "demo_save_dict.pt")
+        print("Done!")
+    else:
+        save_dict = torch.load("demo_save_dict.pt")
+        gnn_save_list = save_dict["gnn_save_list"]
+        untuned_llm_save_list = save_dict["untuned_llm_save_list"]
+        gnn_llm_hallucin_sum = save_dict["gnn_llm_hallucin_sum"]
+        pure_llm_hallucin_sum = save_dict["pure_llm_hallucin_sum"]
+
+    trained_llm_hallucin_sum = 0
+    untuned_llm_hallucin_sum = pure_llm_hallucin_sum
+    final_prnt_str = ""
+    if path.exists("llm.pt") and path.exists("llm_eval_outs.pt"):
+        print("Existing finetuned LLM found.")
+        print("Would you like to retrain?")
+        user_input = str(input("(y/n):")).lower()
+        retrain = user_input == "y"
+    else:
+        retrain = True
+    if retrain:
+        print("Finetuning LLM...")
+        since = time.time()
+        _, _, pure_llm_eval_outputs = train(epochs, None, None, batch_size,
+                                            eval_batch_size, lr,
+                                            model=pure_llm, dataset=dataset,
+                                            model_save_name='llm')
+        e2e_time = round(time.time() - since, 2)
+        print("E2E time (e2e_time) =", e2e_time, "seconds")
+    else:
+        pure_llm_eval_outputs = torch.load("llm_eval_outs.pt")
+    pure_llm_preds = []
+    for out in pure_llm_eval_outputs:
+        pure_llm_preds += out['pred']
+    print("Final comparison between all models...")
+    for i, batch in enumerate(tqdm(loader)):
+        question = batch.question
+        correct_answer = batch.label
+        gnn_llm_pred, gnn_llm_hallucinates = list(
+            zip(*gnn_save_list[i * eval_batch_size:(i + 1) * eval_batch_size]))
+        untuned_llm_pred, untuned_llm_hallucinates = list(
+            zip(*untuned_llm_save_list[i * eval_batch_size:(i + 1) *
+                                       eval_batch_size]))
+        pure_llm_pred = pure_llm_preds[i * eval_batch_size:(i + 1) *
+                                       eval_batch_size]
+        pure_llm_hallucinates = detect_hallucinate(pure_llm_pred,
+                                                   correct_answer)
+        for j in range(len(gnn_llm_pred)):
+            if skip_pretrained_LLM:
+                # we did not check the untrained LLM, so do not decide to demo
+                # based on this.
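+                # Marking the untuned output as hallucinating makes the
+                # three-way comparison below depend only on the tuned LLM
+                # and GNN+LLM results.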
+                # HACK
+                untuned_llm_hallucinates = {j: True}
+            if gnn_llm_hallucinates[j] == "skip" or untuned_llm_hallucinates[
+                    j] == "skip" or pure_llm_hallucinates[j] == "skip":
+                continue
+            trained_llm_hallucin_sum += int(pure_llm_hallucinates[j])
+            if untuned_llm_hallucinates[j] and pure_llm_hallucinates[
+                    j] and not gnn_llm_hallucinates[j]:  # noqa
+                final_prnt_str += "Prompt: '" + question[j] + "'\n"
+                final_prnt_str += "Label: '" + correct_answer[j] + "'\n"
+                if not skip_pretrained_LLM:
+                    final_prnt_str += "Untuned LLM Output: '" \
+                        + untuned_llm_pred[j] + "'\n"  # noqa
+                final_prnt_str += "Tuned LLM Output: '" + pure_llm_pred[
+                    j] + "'\n"
+                final_prnt_str += "GNN+LLM Output: '" + gnn_llm_pred[j] + "'\n"
+                final_prnt_str += "\n" + "#" * 20 + "\n\n"
+    if not skip_pretrained_LLM:
+        print("Total untuned LLM Hallucinations:", untuned_llm_hallucin_sum)
+    print("Total tuned LLM Hallucinations:", trained_llm_hallucin_sum)
+    print("Total GNN+LLM Hallucinations:", gnn_llm_hallucin_sum)
+    if not skip_pretrained_LLM:
+        percent = 100.0 * round(
+            1 - (gnn_llm_hallucin_sum / untuned_llm_hallucin_sum), 2)
+        print(f"GNN reduces untuned LLM hallucinations by: ~{percent}%")
+    tuned_percent = 100.0 * round(
+        1 - (gnn_llm_hallucin_sum / trained_llm_hallucin_sum), 2)
+    print(f"GNN reduces tuned LLM hallucinations by: ~{tuned_percent}%")
+    print("Note: hallucinations detected by regex hence the ~")
+    print("Potential instances where GNN solves the hallucinations of LLM:")
+    print(final_prnt_str)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--gnn_hidden_channels', type=int, default=1024)
+    parser.add_argument('--num_gnn_layers', type=int, default=4)
+    parser.add_argument('--lr', type=float, default=1e-5)
+    parser.add_argument('--epochs', type=int, default=2)
+    parser.add_argument('--batch_size', type=int, default=8)
+    parser.add_argument('--eval_batch_size', type=int, default=16)
+    parser.add_argument('--checkpointing', action='store_true')
+    parser.add_argument('--tiny_llama', action='store_true')
+    parser.add_argument(
+        "--skip_pretrained_llm_eval", action="store_true",
+        help="This flag will skip the evaluation of the pretrained LLM.")
+    args = parser.parse_args()
+
+    start_time = time.time()
+    train(
+        args.epochs,
+        args.gnn_hidden_channels,
+        args.num_gnn_layers,
+        args.batch_size,
+        args.eval_batch_size,
+        args.lr,
+        checkpointing=args.checkpointing,
+        tiny_llama=args.tiny_llama,
+    )
+    print(f"Total Time: {time.time() - start_time:.2f}s")
diff --git a/examples/llm/g_retriever_utils/rag_backend_utils.py b/examples/llm/g_retriever_utils/rag_backend_utils.py
new file mode 100644
index 000000000000..0f1c0e1b87ec
--- /dev/null
+++ b/examples/llm/g_retriever_utils/rag_backend_utils.py
@@ -0,0 +1,224 @@
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Optional,
+    Protocol,
+    Tuple,
+    Type,
+    runtime_checkable,
+)
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+
+from torch_geometric.data import (
+    FeatureStore,
+    GraphStore,
+    LargeGraphIndexer,
+    TripletLike,
+)
+from torch_geometric.data.large_graph_indexer import EDGE_RELATION
+from torch_geometric.distributed import (
+    LocalFeatureStore,
+    LocalGraphStore,
+    Partitioner,
+)
+from torch_geometric.typing import EdgeType, NodeType
+
+RemoteGraphBackend = Tuple[FeatureStore, GraphStore]
+
+# TODO: Make everything compatible with Hetero graphs as well
+
+
+# Adapted from LocalGraphStore
+@runtime_checkable
+class ConvertableGraphStore(Protocol):
+    @classmethod
+    def from_data(
+        cls,
+        edge_id: Tensor,
+        edge_index: Tensor,
+        num_nodes: int,
+        is_sorted: bool = False,
+    ) -> GraphStore:
+        ...
+
+    @classmethod
+    def from_hetero_data(
+        cls,
+        edge_id_dict: Dict[EdgeType, Tensor],
+        edge_index_dict: Dict[EdgeType, Tensor],
+        num_nodes_dict: Dict[NodeType, int],
+        is_sorted: bool = False,
+    ) -> GraphStore:
+        ...
+
+    @classmethod
+    def from_partition(cls, root: str, pid: int) -> GraphStore:
+        ...
+
+
+# Adapted from LocalFeatureStore
+@runtime_checkable
+class ConvertableFeatureStore(Protocol):
+    @classmethod
+    def from_data(
+        cls,
+        node_id: Tensor,
+        x: Optional[Tensor] = None,
+        y: Optional[Tensor] = None,
+        edge_id: Optional[Tensor] = None,
+        edge_attr: Optional[Tensor] = None,
+    ) -> FeatureStore:
+        ...
+
+    @classmethod
+    def from_hetero_data(
+        cls,
+        node_id_dict: Dict[NodeType, Tensor],
+        x_dict: Optional[Dict[NodeType, Tensor]] = None,
+        y_dict: Optional[Dict[NodeType, Tensor]] = None,
+        edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None,
+        edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,
+    ) -> FeatureStore:
+        ...
+
+    @classmethod
+    def from_partition(cls, root: str, pid: int) -> FeatureStore:
+        ...
+
+
+class RemoteDataType(Enum):
+    DATA = auto()
+    PARTITION = auto()
+
+
+@dataclass
+class RemoteGraphBackendLoader:
+    """Utility class to load triplets into a RAG Backend."""
+    path: str
+    datatype: RemoteDataType
+    graph_store_type: Type[ConvertableGraphStore]
+    feature_store_type: Type[ConvertableFeatureStore]
+
+    def load(self, pid: Optional[int] = None) -> RemoteGraphBackend:
+        if self.datatype == RemoteDataType.DATA:
+            data_obj = torch.load(self.path)
+            graph_store = self.graph_store_type.from_data(
+                edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index,
+                num_nodes=data_obj.num_nodes)
+            feature_store = self.feature_store_type.from_data(
+                node_id=data_obj['node_id'], x=data_obj.x,
+                edge_id=data_obj['edge_id'], edge_attr=data_obj.edge_attr)
+        elif self.datatype == RemoteDataType.PARTITION:
+            assert pid is not None, \
+                "Partition ID must be defined for loading from a " \
+                "partitioned store."
+            graph_store = self.graph_store_type.from_partition(self.path, pid)
+            feature_store = self.feature_store_type.from_partition(
+                self.path, pid)
+        else:
+            raise NotImplementedError
+        return (feature_store, graph_store)
+
+
+# TODO: make profilable
+def create_remote_backend_from_triplets(
+        triplets: Iterable[TripletLike], node_embedding_model: Module,
+        edge_embedding_model: Module | None = None,
+        graph_db: Type[ConvertableGraphStore] = LocalGraphStore,
+        feature_db: Type[ConvertableFeatureStore] = LocalFeatureStore,
+        node_method_to_call: str = "forward",
+        edge_method_to_call: str | None = None,
+        pre_transform: Callable[[TripletLike], TripletLike] | None = None,
+        path: str = '', n_parts: int = 1,
+        node_method_kwargs: Optional[Dict[str, Any]] = None,
+        edge_method_kwargs: Optional[Dict[str, Any]] = None
+) -> RemoteGraphBackendLoader:
+    """Utility function that can be used to create a RAG Backend from triplets.
+
+    Args:
+        triplets (Iterable[TripletLike]): Triplets to load into the RAG
+            Backend.
+        node_embedding_model (Module): Model to embed nodes into a feature
+            space.
+        edge_embedding_model (Module | None, optional): Model to embed edges
+            into a feature space. Defaults to the node model.
+        graph_db (Type[ConvertableGraphStore], optional): GraphStore class to
+            use. Defaults to LocalGraphStore.
+        feature_db (Type[ConvertableFeatureStore], optional): FeatureStore
+            class to use. Defaults to LocalFeatureStore.
+        node_method_to_call (str, optional): method to call for embeddings on
+            the node model. Defaults to "forward".
+        edge_method_to_call (str | None, optional): method to call for
+            embeddings on the edge model. Defaults to the node method.
+        pre_transform (Callable[[TripletLike], TripletLike] | None, optional):
+            optional preprocessing function for triplets. Defaults to None.
+        path (str, optional): path to save resulting stores. Defaults to ''.
+        n_parts (int, optional): Number of partitions to store in.
+            Defaults to 1.
+        node_method_kwargs (Optional[Dict[str, Any]], optional): args to pass
+            into node encoding method. Defaults to None.
+        edge_method_kwargs (Optional[Dict[str, Any]], optional): args to pass
+            into edge encoding method. Defaults to None.
+
+    Returns:
+        RemoteGraphBackendLoader: Loader to load RAG backend from disk or
+            memory.
+    """
+    # These checks raise AttributeError for missing attributes
+    if not issubclass(graph_db, ConvertableGraphStore):
+        getattr(graph_db, "from_data")
+        getattr(graph_db, "from_hetero_data")
+        getattr(graph_db, "from_partition")
+    if not issubclass(feature_db, ConvertableFeatureStore):
+        getattr(feature_db, "from_data")
+        getattr(feature_db, "from_hetero_data")
+        getattr(feature_db, "from_partition")
+
+    # Resolve callable methods
+    node_method_kwargs = node_method_kwargs \
+        if node_method_kwargs is not None else dict()
+
+    edge_embedding_model = edge_embedding_model \
+        if edge_embedding_model is not None else node_embedding_model
+    edge_method_to_call = edge_method_to_call \
+        if edge_method_to_call is not None else node_method_to_call
+    edge_method_kwargs = edge_method_kwargs \
+        if edge_method_kwargs is not None else node_method_kwargs
+
+    # These will raise AttributeError if they don't exist
+    node_model = getattr(node_embedding_model, node_method_to_call)
+    edge_model = getattr(edge_embedding_model, edge_method_to_call)
+
+    indexer = LargeGraphIndexer.from_triplets(triplets,
+                                              pre_transform=pre_transform)
+
+    node_feats = node_model(indexer.get_node_features(), **node_method_kwargs)
+    indexer.add_node_feature('x', node_feats)
+
+    edge_feats = edge_model(
+        indexer.get_unique_edge_features(feature_name=EDGE_RELATION),
+        **edge_method_kwargs)
+    indexer.add_edge_feature(new_feature_name="edge_attr",
+                             new_feature_vals=edge_feats,
+                             map_from_feature=EDGE_RELATION)
+
+    data = indexer.to_data(node_feature_name='x',
+                           edge_feature_name='edge_attr')
+
+    if n_parts == 1:
+        torch.save(data, path)
+        return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db,
+                                        feature_db)
+    else:
+        partitioner = Partitioner(data=data, num_parts=n_parts, root=path)
+        partitioner.generate_partition()
+        return RemoteGraphBackendLoader(path, RemoteDataType.PARTITION,
+                                        graph_db, feature_db)
diff --git a/examples/llm/g_retriever_utils/rag_feature_store.py b/examples/llm/g_retriever_utils/rag_feature_store.py
new file mode 100644
index 000000000000..e01e9e59bb88
--- /dev/null
+++ b/examples/llm/g_retriever_utils/rag_feature_store.py
@@ -0,0 +1,189 @@
+import gc
+from collections.abc import Iterable, Iterator
+from typing import Any, Dict, Optional, Type, Union
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+from torchmetrics.functional import pairwise_cosine_similarity
+
+from torch_geometric.data import Data, HeteroData
+from torch_geometric.distributed import LocalFeatureStore
+from torch_geometric.nn.nlp import SentenceTransformer
+from torch_geometric.nn.pool import ApproxMIPSKNNIndex
+from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
+from torch_geometric.typing import InputEdges, InputNodes
+
+
+# NOTE: Only compatible with Homogeneous graphs for now
+class KNNRAGFeatureStore(LocalFeatureStore):
+    def __init__(self, enc_model: Type[Module],
+                 model_kwargs: Optional[Dict[str,
+                                             Any]] = None, *args, **kwargs):
+        self.device = torch.device(
+            "cuda" if torch.cuda.is_available() else "cpu")
+        self.enc_model = enc_model(*args, **kwargs).to(self.device)
+        self.enc_model.eval()
+        self.model_kwargs = \
+            model_kwargs if model_kwargs is not None else dict()
+        super().__init__()
+
+    @property
+    def x(self) -> Tensor:
+        return self.get_tensor(group_name=None, attr_name='x')
+
+    @property
+    def edge_attr(self) -> Tensor:
+        return self.get_tensor(group_name=(None, None), attr_name='edge_attr')
+
+    def retrieve_seed_nodes(self, query: Any, k_nodes: int = 5) -> InputNodes:
+        result = next(self._retrieve_seed_nodes_batch([query], k_nodes))
+        gc.collect()
+        torch.cuda.empty_cache()
+        return result
+
+    def _retrieve_seed_nodes_batch(self, query: Iterable[Any],
+                                   k_nodes: int) -> Iterator[InputNodes]:
+        if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
+            raise NotImplementedError
+
+        query_enc = self.enc_model.encode(query,
+                                          **self.model_kwargs).to(self.device)
+        prizes = pairwise_cosine_similarity(query_enc, self.x.to(self.device))
+        topk = min(k_nodes, len(self.x))
+        for q in prizes:
+            _, indices = torch.topk(q, topk, largest=True)
+            yield indices
+
+    def retrieve_seed_edges(self, query: Any, k_edges: int = 3) -> InputEdges:
+        result = next(self._retrieve_seed_edges_batch([query], k_edges))
+        gc.collect()
+        torch.cuda.empty_cache()
+        return result
+
+    def _retrieve_seed_edges_batch(self, query: Iterable[Any],
+                                   k_edges: int) -> Iterator[InputEdges]:
+        if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
+            raise NotImplementedError
+
+        query_enc = self.enc_model.encode(query,
+                                          **self.model_kwargs).to(self.device)
+
+        prizes = pairwise_cosine_similarity(query_enc,
+                                            self.edge_attr.to(self.device))
+        topk = min(k_edges, len(self.edge_attr))
+        for q in prizes:
+            _, indices = torch.topk(q, topk, largest=True)
+            yield indices
+
+    def load_subgraph(
+        self, sample: Union[SamplerOutput, HeteroSamplerOutput]
+    ) -> Union[Data, HeteroData]:
+
+        if isinstance(sample, HeteroSamplerOutput):
+            raise NotImplementedError
+
+        # NOTE: torch_geometric.loader.utils.filter_custom_store can be used
+        # here if it supported edge features
+        node_id = sample.node
+        edge_id = sample.edge
+        edge_index = torch.stack((sample.row, sample.col), dim=0)
+        x = self.x[node_id]
+        edge_attr = self.edge_attr[edge_id]
+
+        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
+                    node_idx=node_id, edge_idx=edge_id)
+
+
+# TODO: Refactor because composition >> inheritance
+
+
+def _add_features_to_knn_index(knn_index: ApproxMIPSKNNIndex, emb: Tensor,
+                               device: torch.device, batch_size: int = 2**20):
+    """Add new features to the existing KNN index in batches.
+
+    Args:
+        knn_index (ApproxMIPSKNNIndex): Index to add features to.
+        emb (Tensor): Embeddings to add.
+        device (torch.device): Device to store in.
+        batch_size (int, optional): Batch size to iterate by.
+            Defaults to 2**20, which equates to 4GB if working with
+            1024 dim floats.
+    """
+    for i in range(0, emb.size(0), batch_size):
+        if emb.size(0) - i >= batch_size:
+            emb_batch = emb[i:i + batch_size].to(device)
+        else:
+            emb_batch = emb[i:].to(device)
+        knn_index.add(emb_batch)
+
+
+class ApproxKNNRAGFeatureStore(KNNRAGFeatureStore):
+    def __init__(self, enc_model: Type[Module],
+                 model_kwargs: Optional[Dict[str,
+                                             Any]] = None, *args, **kwargs):
+        # TODO: Add kwargs for approx KNN to parameters here.
+        super().__init__(enc_model, model_kwargs, *args, **kwargs)
+        self.node_knn_index = None
+        self.edge_knn_index = None
+
+    def _retrieve_seed_nodes_batch(self, query: Iterable[Any],
+                                   k_nodes: int) -> Iterator[InputNodes]:
+        if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
+            raise NotImplementedError
+
+        enc_model = self.enc_model.to(self.device)
+        query_enc = enc_model.encode(query,
+                                     **self.model_kwargs).to(self.device)
+        del enc_model
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        if self.node_knn_index is None:
+            self.node_knn_index = ApproxMIPSKNNIndex(num_cells=100,
+                                                     num_cells_to_visit=100,
+                                                     bits_per_vector=4)
+            # Need to add in batches to avoid OOM
+            _add_features_to_knn_index(self.node_knn_index, self.x,
+                                       self.device)
+
+        output = self.node_knn_index.search(query_enc, k=k_nodes)
+        yield from output.index
+
+    def _retrieve_seed_edges_batch(self, query: Iterable[Any],
+                                   k_edges: int) -> Iterator[InputEdges]:
+        if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
+            raise NotImplementedError
+
+        enc_model = self.enc_model.to(self.device)
+        query_enc = enc_model.encode(query,
+                                     **self.model_kwargs).to(self.device)
+        del enc_model
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        if self.edge_knn_index is None:
+            self.edge_knn_index = ApproxMIPSKNNIndex(num_cells=100,
+                                                     num_cells_to_visit=100,
+                                                     bits_per_vector=4)
+            # Need to add in batches to avoid OOM
+            _add_features_to_knn_index(self.edge_knn_index, self.edge_attr,
+                                       self.device)
+
+        output = self.edge_knn_index.search(query_enc, k=k_edges)
+        yield from output.index
+
+
+# TODO: These two classes should be refactored
+class SentenceTransformerFeatureStore(KNNRAGFeatureStore):
+    def __init__(self, *args, **kwargs):
+        kwargs['model_name'] = kwargs.get(
+            'model_name', 'sentence-transformers/all-roberta-large-v1')
+        super().__init__(SentenceTransformer, *args, **kwargs)
+
+
+class SentenceTransformerApproxFeatureStore(ApproxKNNRAGFeatureStore):
+    def __init__(self, *args, **kwargs):
+        kwargs['model_name'] = kwargs.get(
+            'model_name', 'sentence-transformers/all-roberta-large-v1')
+        super().__init__(SentenceTransformer, *args, **kwargs)
diff --git a/examples/llm/g_retriever_utils/rag_generate.py b/examples/llm/g_retriever_utils/rag_generate.py
new file mode 100644
index 000000000000..896fbd7598b1
--- /dev/null
+++ b/examples/llm/g_retriever_utils/rag_generate.py
@@ -0,0 +1,139 @@
+# %%
+import argparse
+from itertools import chain
+from typing import Tuple
+
+import pandas as pd
+import torch
+import tqdm
+from rag_backend_utils import create_remote_backend_from_triplets
+from rag_feature_store import SentenceTransformerFeatureStore
+from rag_graph_store import NeighborSamplingRAGGraphStore
+
+from torch_geometric.data import Data
+from torch_geometric.datasets import WebQSPDataset
+from torch_geometric.datasets.web_qsp_dataset import (
+    preprocess_triplet,
+    retrieval_via_pcst,
+)
+from torch_geometric.loader import RAGQueryLoader
+from torch_geometric.nn.nlp import SentenceTransformer
+
+# %%
+parser = argparse.ArgumentParser(
+    description="""Generate new WebQSP subgraphs\n"""
+    + """NOTE: Evaluating with smaller samples may result in"""
+    + """ poorer performance for the trained models compared"""
+    + """ to untrained models.""")
+# TODO: Add more arguments for configuring rag params
+parser.add_argument("--use_pcst", action="store_true")
+parser.add_argument("--num_samples", type=int, default=4700)
+parser.add_argument("--out_file", default="subg_results.pt")
+args = parser.parse_args()
+
+# %%
+ds = WebQSPDataset("dataset", limit=args.num_samples, verbose=True,
+                   force_reload=True)
+
+# %%
+triplets = chain.from_iterable(d['graph'] for d in ds.raw_dataset)
+
+# %%
+questions = ds.raw_dataset['question']
+
+# %%
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = SentenceTransformer(
+    model_name='sentence-transformers/all-roberta-large-v1').to(device)
+
+# %%
+fs, gs = create_remote_backend_from_triplets(
+    triplets=triplets, node_embedding_model=model,
+    node_method_to_call="encode", path="backend",
+    pre_transform=preprocess_triplet, node_method_kwargs={
+        "batch_size": 256
+    }, graph_db=NeighborSamplingRAGGraphStore,
+    feature_db=SentenceTransformerFeatureStore).load()
+
+# %%
+
+
+def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3,
+                             topk_e: int = 3,
+                             cost_e: float = 0.5) -> Tuple[Data, str]:
+    q_emb = model.encode(query)
+    textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index()
+    textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index()
+    out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes,
+                                         textual_edges, topk, topk_e, cost_e)
+    out_graph["desc"] = desc
+    return out_graph
+
+
+def apply_retrieval_with_text(graph: Data, query: str) -> Tuple[Data, str]:
+    textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index()
+    textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index()
+    desc = (
+        textual_nodes.to_csv(index=False) + "\n" +
+        textual_edges.to_csv(index=False, columns=["src", "edge_attr", "dst"]))
+    graph["desc"] = desc
+    return graph
+
+
+transform = apply_retrieval_via_pcst \
+    if args.use_pcst else apply_retrieval_with_text
+
+query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 5},
+                              seed_edges_kwargs={"k_edges": 5},
+                              sampler_kwargs={"num_neighbors": [50] * 2},
+                              local_filter=transform)
+
+
+# %%
+# Accuracy Metrics to be added to Profiler
+def _eidx_helper(subg: Data, ground_truth: Data):
+    subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx
+    if isinstance(subg_eidx, torch.Tensor):
+        subg_eidx = subg_eidx.tolist()
+    if isinstance(gt_eidx, torch.Tensor):
+        gt_eidx = gt_eidx.tolist()
+    subg_e = set(subg_eidx)
+    gt_e = set(gt_eidx)
+    return subg_e, gt_e
+
+
+def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int):
+    subg_e, gt_e = _eidx_helper(subg, ground_truth)
+    total_e = set(range(num_edges))
+    tp = len(subg_e & gt_e)
+    tn = len(total_e - (subg_e | gt_e))
+    return (tp + tn) / num_edges
+
+
+def check_retrieval_precision(subg: Data, ground_truth: Data):
+    subg_e, gt_e = _eidx_helper(subg, ground_truth)
+    return len(subg_e & gt_e) / len(subg_e)
+
+
+def check_retrieval_recall(subg: Data, ground_truth: Data):
+    subg_e, gt_e = _eidx_helper(subg, ground_truth)
+    return len(subg_e & gt_e) / len(gt_e)
+
+
+# %%
+retrieval_stats = {"precision": [], "recall": [], "accuracy": []}
+subgs = []
+node_len = []
+edge_len = []
+for subg in tqdm.tqdm(query_loader.query(q) for q in questions):
+    subgs.append(subg)
+    node_len.append(subg['x'].shape[0])
+    edge_len.append(subg['edge_attr'].shape[0])
+
+for i, subg in enumerate(subgs):
+    subg['question'] = questions[i]
+    subg['label'] = ds[i]['label']
+
+pd.DataFrame.from_dict(retrieval_stats).to_csv(
+    args.out_file.split('.')[0] + '_metadata.csv')
+torch.save(subgs, args.out_file)
diff --git a/examples/llm/g_retriever_utils/rag_graph_store.py b/examples/llm/g_retriever_utils/rag_graph_store.py
new file mode 100644
index 000000000000..48473f287233
--- /dev/null
+++ b/examples/llm/g_retriever_utils/rag_graph_store.py
@@ -0,0 +1,107 @@
+from typing import Optional, Union
+
+import torch
+from torch import Tensor
+
+from torch_geometric.data import FeatureStore
+from torch_geometric.distributed import LocalGraphStore
+from torch_geometric.sampler import (
+    HeteroSamplerOutput,
+    NeighborSampler,
+    NodeSamplerInput,
+    SamplerOutput,
+)
+from torch_geometric.sampler.neighbor_sampler import NumNeighborsType
+from torch_geometric.typing import EdgeTensorType, InputEdges, InputNodes
+
+
+class NeighborSamplingRAGGraphStore(LocalGraphStore):
+    def __init__(self, feature_store: Optional[FeatureStore] = None,
+                 num_neighbors: NumNeighborsType = [1], **kwargs):
+        self.feature_store = feature_store
+        self._num_neighbors = num_neighbors
+        self.sample_kwargs = kwargs
+        self._sampler_is_initialized = False
+        super().__init__()
+
+    def _init_sampler(self):
+        if self.feature_store is None:
+            raise AttributeError("Feature store not registered yet.")
+        self.sampler = NeighborSampler(data=(self.feature_store, self),
+                                       num_neighbors=self._num_neighbors,
+                                       **self.sample_kwargs)
+        self._sampler_is_initialized = True
+
+    def register_feature_store(self, feature_store: FeatureStore):
+        self.feature_store = feature_store
+        self._sampler_is_initialized = False
+
+    def put_edge_id(self, edge_id: Tensor, *args, **kwargs) -> bool:
+        ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs)
+        self._sampler_is_initialized = False
+        return ret
+
+    @property
+    def edge_index(self):
+        return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs)
+
+    def put_edge_index(self, edge_index: EdgeTensorType, *args,
+                       **kwargs) -> bool:
+        ret = super().put_edge_index(edge_index, *args, **kwargs)
+        # HACK: Cache the args so the `edge_index` property can re-fetch them.
+        self.edge_idx_args = args
+        self.edge_idx_kwargs = kwargs
+        self._sampler_is_initialized = False
+        return ret
+
+    @property
+    def num_neighbors(self):
+        return self._num_neighbors
+
+    @num_neighbors.setter
+    def num_neighbors(self, num_neighbors: NumNeighborsType):
+        self._num_neighbors = num_neighbors
+        if hasattr(self, 'sampler'):
+            self.sampler.num_neighbors = num_neighbors
+
+    def sample_subgraph(
+        self, seed_nodes: InputNodes, seed_edges: InputEdges,
+        num_neighbors: Optional[NumNeighborsType] = None
+    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
+        """Sample the graph starting from the given nodes and edges using the
+        in-built NeighborSampler.
+
+        Args:
+            seed_nodes (InputNodes): Seed nodes to start sampling from.
+            seed_edges (InputEdges): Seed edges to start sampling from.
+            num_neighbors (Optional[NumNeighborsType], optional): Parameters
+                to determine how many hops and number of neighbors per hop.
+                Defaults to None.
+
+        Returns:
+            Union[SamplerOutput, HeteroSamplerOutput]: NeighborSamplerOutput
+                for the input.
+        """
+        if not self._sampler_is_initialized:
+            self._init_sampler()
+        if num_neighbors is not None:
+            self.num_neighbors = num_neighbors
+
+        # FIXME: Right now, only input nodes/edges as tensors are supported
+        if not isinstance(seed_nodes, Tensor):
+            raise NotImplementedError
+        if not isinstance(seed_edges, Tensor):
+            raise NotImplementedError
+        device = seed_nodes.device
+
+        # TODO: Call sample_from_edges for seed_edges
+        # Turning them into nodes for now.
+        seed_edges = self.edge_index.to(device).T[seed_edges.to(
+            device)].reshape(-1)
+        seed_nodes = torch.cat((seed_nodes, seed_edges), dim=0)
+
+        seed_nodes = seed_nodes.unique().contiguous()
+        node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes)
+        out = self.sampler.sample_from_nodes(node_sample_input)
+
+        return out
diff --git a/examples/llm/multihop_rag/README.md b/examples/llm/multihop_rag/README.md
new file mode 100644
index 000000000000..ff43b16a2c05
--- /dev/null
+++ b/examples/llm/multihop_rag/README.md
@@ -0,0 +1,9 @@
+# Examples for LLM and GNN co-training
+
+| Example                                                   | Description                                                                                                                                 |
+| ---------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
+| [`multihop_download.sh`](./multihop_download.sh)          | Downloads all the components of the multihop dataset.                                                                                       |
+| [`multihop_preprocess.py`](./multihop_preprocess.py)      | Preprocesses the dataset to pair questions/answers with components in the knowledge graph. Contains documentation to describe the process. |
+| [`rag_generate_multihop.py`](./rag_generate_multihop.py)  | Utilizes the sample remote backend in [`g_retriever_utils`](../g_retriever_utils/) to generate subgraphs for the multihop dataset.         |
+
+NOTE: Performance of GRetriever on this dataset has not been evaluated.
diff --git a/examples/llm/multihop_rag/multihop_download.sh b/examples/llm/multihop_rag/multihop_download.sh
new file mode 100644
index 000000000000..3c1970d39440
--- /dev/null
+++ b/examples/llm/multihop_rag/multihop_download.sh
@@ -0,0 +1,12 @@
+#!/bin/sh
+
+# Wikidata5m
+
+wget -O "wikidata5m_alias.tar.gz" "https://www.dropbox.com/s/lnbhc8yuhit4wm5/wikidata5m_alias.tar.gz"
+tar -xvf "wikidata5m_alias.tar.gz"
+wget -O "wikidata5m_all_triplet.txt.gz" "https://www.dropbox.com/s/563omb11cxaqr83/wikidata5m_all_triplet.txt.gz"
+gzip -d "wikidata5m_all_triplet.txt.gz" -f
+
+# 2WikiMultiHopQA
+wget -O "data_ids_april7.zip" "https://www.dropbox.com/s/ms2m13252h6xubs/data_ids_april7.zip"
+unzip -o "data_ids_april7.zip"
diff --git a/examples/llm/multihop_rag/multihop_preprocess.py b/examples/llm/multihop_rag/multihop_preprocess.py
new file mode 100644
index 000000000000..46052bdf1b15
--- /dev/null
+++ b/examples/llm/multihop_rag/multihop_preprocess.py
@@ -0,0 +1,276 @@
+"""Example workflow for downloading and assembling a multihop QA dataset."""
+
+import argparse
+import json
+from subprocess import call
+
+import pandas as pd
+import torch
+import tqdm
+
+from torch_geometric.data import LargeGraphIndexer
+
+# %% [markdown]
+# # Encoding A Large Knowledge Graph Part 2
+
+# %% [markdown]
+# In this notebook, we will continue where we left off by building a new
+# multi-hop QA dataset based on Wikidata.
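+
+# %% [markdown]
+# In outline, the steps below (1) download Wikidata5M and 2WikiMultiHopQA,
+# (2) replace entity/relation codes with plaintext aliases, (3) pair each
+# multi-hop question with its evidence triplets, and (4) index the graph
+# into node and edge tables for retrieval.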
+
+# %% [markdown]
+# ## Example 2: Building a new Dataset from Questions and an already-existing
+# Knowledge Graph
+
+# %% [markdown]
+# ### Motivation
+
+# %% [markdown]
+# One potential application of knowledge graph structural encodings is
+# capturing the relationships between different entities that are multiple
+# hops apart. This can be challenging for an LLM to recognize from prepended
+# graph information. Here's a motivating example (credit to @Rishi Puri):
+
+# %% [markdown]
+# In this example, the question can only be answered by reasoning about the
+# relationships between the entities in the knowledge graph.
+
+# %% [markdown]
+# ### Building a Multi-Hop QA Dataset
+
+# %% [markdown]
+# To start, we need to download the raw data of a knowledge graph.
+# In this case, we use WikiData5M
+# ([Wang et al.]
+# (https://paperswithcode.com/paper/kepler-a-unified-model-for-knowledge)).
+# Here we download the raw triplets and their entity codes. Information about
+# this dataset can be found
+# [here](https://deepgraphlearning.github.io/project/wikidata5m).
+
+# %% [markdown]
+# The following download contains the ID to plaintext mapping for all the
+# entities and relations in the knowledge graph:
+
+rv = call("./multihop_download.sh")
+
+# %% [markdown]
+# To start, we are going to preprocess the knowledge graph to substitute each
+# of the entity/relation codes with their plaintext aliases. This makes it
+# easier to use a pre-trained textual encoding model to create triplet
+# embeddings, as such a model likely won't understand how to properly embed
+# the entity codes.
+
+# %%
+
+# %%
+parser = argparse.ArgumentParser(description="Preprocess wikidata5m")
+parser.add_argument("--n_triplets", type=int, default=-1)
+args = parser.parse_args()
+
+# %%
+# Substitute entity codes with their aliases
+# Picking the first alias for each entity (rather arbitrarily)
+alias_map = {}
+rel_alias_map = {}
+for line in open('wikidata5m_entity.txt'):
+    parts = line.strip().split('\t')
+    entity_id = parts[0]
+    aliases = parts[1:]
+    alias_map[entity_id] = aliases[0]
+for line in open('wikidata5m_relation.txt'):
+    parts = line.strip().split('\t')
+    relation_id = parts[0]
+    relation_name = parts[1]
+    rel_alias_map[relation_id] = relation_name
+
+# %%
+full_graph = []
+missing_total = 0
+total = 0
+limit = None if args.n_triplets == -1 else args.n_triplets
+i = 0
+
+for line in tqdm.tqdm(open('wikidata5m_all_triplet.txt')):
+    if limit is not None and i >= limit:
+        break
+    src, rel, dst = line.strip().split('\t')
+    if src not in alias_map:
+        missing_total += 1
+    if dst not in alias_map:
+        missing_total += 1
+    if rel not in rel_alias_map:
+        missing_total += 1
+    total += 3
+    full_graph.append([
+        alias_map.get(src, src),
+        rel_alias_map.get(rel, rel),
+        alias_map.get(dst, dst)
+    ])
+    i += 1
+print(f"Missing aliases: {missing_total}/{total}")
+
+# %% [markdown]
+# Now `full_graph` represents the knowledge graph triplets in
+# understandable plaintext.
+
+# %% [markdown]
+# Next, we need a set of multi-hop questions that the Knowledge Graph will
+# provide us with context for. We utilize a subset of
+# [HotPotQA](https://hotpotqa.github.io/)
+# ([Yang et al.](https://arxiv.org/pdf/1809.09600)) called
+# [2WikiMultiHopQA](https://github.com/Alab-NII/2wikimultihop)
+# ([Ho et al.](https://aclanthology.org/2020.coling-main.580.pdf)),
+# al.](https://aclanthology.org/2020.coling-main.580.pdf)),
+# which includes a subgraph of entities that serve as the ground truth
+# justification for answering each multi-hop question:
+
+# %%
+with open('train.json') as f:
+    train_data = json.load(f)
+train_df = pd.DataFrame(train_data)
+train_df['split_type'] = 'train'
+
+with open('dev.json') as f:
+    dev_data = json.load(f)
+dev_df = pd.DataFrame(dev_data)
+dev_df['split_type'] = 'dev'
+
+with open('test.json') as f:
+    test_data = json.load(f)
+test_df = pd.DataFrame(test_data)
+test_df['split_type'] = 'test'
+
+df = pd.concat([train_df, dev_df, test_df])
+
+# %% [markdown]
+# Now we need to extract the subgraphs:
+
+# %%
+df['graph_size'] = df['evidences_id'].apply(lambda row: len(row))
+
+# %% [markdown]
+# (Optional) We can keep only the questions whose evidence graph is
+# non-empty. (Note: this gets rid of the test set):
+
+# %%
+# df = df[df['graph_size'] > 0]
+
+# %%
+refined_df = df[[
+    '_id', 'question', 'answer', 'split_type', 'evidences_id', 'type',
+    'graph_size'
+]]
+
+# %% [markdown]
+# Checkpoint:
+
+# %%
+refined_df.to_csv('wikimultihopqa_refined.csv', index=False)
+
+# %% [markdown]
+# Now we need to check that all the entities mentioned in the question/answer
+# set are also present in the Wikidata graph:
+
+# %%
+relation_map = {}
+with open('wikidata5m_relation.txt') as f:
+    for line in tqdm.tqdm(f):
+        parts = line.strip().split('\t')
+        for i in range(1, len(parts)):
+            if parts[i] not in relation_map:
+                relation_map[parts[i]] = []
+            relation_map[parts[i]].append(parts[0])
+
+# %%
+entity_set = set()
+with open('wikidata5m_entity.txt') as f:
+    for line in tqdm.tqdm(f):
+        entity_set.add(line.strip().split('\t')[0])
+
+# %%
+missing_entities = set()
+missing_entity_idx = set()
+for i, row in enumerate(refined_df.itertuples()):
+    for trip in row.evidences_id:
+        entities = trip[0], trip[2]
+        for entity in entities:
+            if entity not in entity_set:
+                # print(
+                #     f'The following entity was not found in the KG: {entity}'
+                # )
+                missing_entities.add(entity)
+                missing_entity_idx.add(i)
+
+# %% [markdown]
+# Right now, we drop the missing entity entries. Additional preprocessing can
+# be done here to resolve the entity/relation collisions, but that is out of
+# scope for this notebook.
+
+# %%
+# missing relations are ok, but missing entities cannot be mapped to
+# plaintext, so they should be dropped.
+refined_df.reset_index(inplace=True, drop=True)
+
+# %%
+cleaned_df = refined_df.drop(missing_entity_idx)
+
+# %% [markdown]
+# Now we save the resulting graph and questions/answers dataset:
+
+# %%
+cleaned_df.to_csv('wikimultihopqa_cleaned.csv', index=False)
+
+# %%
+
+# %%
+torch.save(full_graph, 'wikimultihopqa_full_graph.pt')
+
+# %% [markdown]
+# ### Question: How do we extract a contextual subgraph for a given query?
+
+# %% [markdown]
+# The chosen retrieval algorithm is a critical component of the pipeline and
+# strongly affects RAG performance. In the next section, we demonstrate a
+# naive retrieval method for a large knowledge graph and how to apply it to
+# this dataset along with WebQSP.
+
+# %% [markdown]
+# ### Preparing a Textualized Graph for LLM
+
+# %% [markdown]
+# For now, however, we need to prepare the graph data to be used as a
+# plaintext prefix for the LLM. To do this, we want to prompt the LLM with
+# the unique nodes and unique edge triplets of a given subgraph, so we now
+# prepare a uniquely indexed node DataFrame and edge DataFrame for the
+# knowledge graph.
This process occurs trivially with the LargeGraphIndexer: + +# %% + +# %% +indexer = LargeGraphIndexer.from_triplets(full_graph) + +# %% +# Node DF +textual_nodes = pd.DataFrame.from_dict( + {"node_attr": indexer.get_node_features()}) +textual_nodes["node_id"] = textual_nodes.index +textual_nodes = textual_nodes[["node_id", "node_attr"]] + +# %% [markdown] +# Notice how LargeGraphIndexer ensures that there are no duplicate indices: + +# %% +# Edge DF +textual_edges = pd.DataFrame(indexer.get_edge_features(), + columns=["src", "edge_attr", "dst"]) +textual_edges["src"] = [indexer._nodes[h] for h in textual_edges["src"]] +textual_edges["dst"] = [indexer._nodes[h] for h in textual_edges["dst"]] + +# %% [markdown] +# Note: The edge table refers to each node by its index in the node table. +# We will see how this gets utilized later when indexing a subgraph. + +# %% [markdown] +# Now we can save the result + +# %% +textual_nodes.to_csv('wikimultihopqa_textual_nodes.csv', index=False) +textual_edges.to_csv('wikimultihopqa_textual_edges.csv', index=False) diff --git a/examples/llm/multihop_rag/rag_generate_multihop.py b/examples/llm/multihop_rag/rag_generate_multihop.py new file mode 100644 index 000000000000..de93a9e75dd1 --- /dev/null +++ b/examples/llm/multihop_rag/rag_generate_multihop.py @@ -0,0 +1,88 @@ +# %% +import argparse +import sys +from typing import Tuple + +import pandas as pd +import torch +import tqdm + +from torch_geometric.data import Data +from torch_geometric.datasets.web_qsp_dataset import ( + preprocess_triplet, + retrieval_via_pcst, +) +from torch_geometric.loader import RAGQueryLoader +from torch_geometric.nn.nlp import SentenceTransformer + +sys.path.append('..') + +from g_retriever_utils.rag_backend_utils import \ + create_remote_backend_from_triplets # noqa: E402 +from g_retriever_utils.rag_feature_store import \ + SentenceTransformerApproxFeatureStore # noqa: E402 +from g_retriever_utils.rag_graph_store import \ + NeighborSamplingRAGGraphStore # noqa: E402 + +# %% +parser = argparse.ArgumentParser( + description="Generate new multihop dataset for rag") +# TODO: Add more arguments for configuring rag params +parser.add_argument("--num_samples", type=int) +args = parser.parse_args() + +# %% +triplets = torch.load('wikimultihopqa_full_graph.pt') + +# %% +df = pd.read_csv('wikimultihopqa_cleaned.csv') +questions = df['question'][:args.num_samples] +labels = df['answer'][:args.num_samples] + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SentenceTransformer( + model_name='sentence-transformers/all-roberta-large-v1').to(device) + +# %% +fs, gs = create_remote_backend_from_triplets( + triplets=triplets, node_embedding_model=model, + node_method_to_call="encode", path="backend", + pre_transform=preprocess_triplet, node_method_kwargs={ + "batch_size": 256 + }, graph_db=NeighborSamplingRAGGraphStore, + feature_db=SentenceTransformerApproxFeatureStore).load() + +# %% + +all_textual_nodes = pd.read_csv('wikimultihopqa_textual_nodes.csv') +all_textual_edges = pd.read_csv('wikimultihopqa_textual_edges.csv') + + +def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3, + topk_e: int = 3, + cost_e: float = 0.5) -> Tuple[Data, str]: + q_emb = model.encode(query) + textual_nodes = all_textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = all_textual_edges.iloc[graph["edge_idx"]].reset_index() + out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, + textual_edges, topk, topk_e, cost_e) + out_graph["desc"] = 
desc + return out_graph + + +# %% +query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 10}, + seed_edges_kwargs={"k_edges": 10}, + sampler_kwargs={"num_neighbors": [40] * 3}, + local_filter=apply_retrieval_via_pcst) + +# %% +subgs = [] +for q, l in tqdm.tqdm(zip(questions, labels)): + subg = query_loader.query(q) + subg['question'] = q + subg['label'] = l + subgs.append(subg) + +torch.save(subgs, 'subg_results.pt') diff --git a/examples/llm/nvtx_examples/README.md b/examples/llm/nvtx_examples/README.md new file mode 100644 index 000000000000..aa4f070d9824 --- /dev/null +++ b/examples/llm/nvtx_examples/README.md @@ -0,0 +1,7 @@ +# Examples for LLM and GNN co-training + +| Example | Description | +| -------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | +| [`nvtx_run.sh`](./nvtx_run.sh) | Runs nsys profiler on a given Python file that contains NVTX calls. | +| [`nvtx_rag_backend_example.py`](./nvtx_rag_backend_example.py) | Example script for nsys profiling a RAG Backend such as that used in [`rag_generate.py`](../g_retriever_utils/rag_generate.py). | +| [`nvtx_webqsp_example.py`](./nvtx_webqsp_example.py) | Example script for nsys profiling the WebQSP dataset. | diff --git a/examples/llm/nvtx_examples/nvtx_rag_backend_example.py b/examples/llm/nvtx_examples/nvtx_rag_backend_example.py new file mode 100644 index 000000000000..b30e34b8c7b1 --- /dev/null +++ b/examples/llm/nvtx_examples/nvtx_rag_backend_example.py @@ -0,0 +1,144 @@ +# %% +import argparse +import sys +from itertools import chain +from typing import Tuple + +import torch + +from torch_geometric.data import Data, get_features_for_triplets_groups +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.datasets.web_qsp_dataset import ( + preprocess_triplet, + retrieval_via_pcst, +) +from torch_geometric.loader import rag_loader +from torch_geometric.nn.nlp import SentenceTransformer +from torch_geometric.profile.nvtx import nvtxit + +sys.path.append('..') +from g_retriever_utils.rag_backend_utils import \ + create_remote_backend_from_triplets # noqa: E402 +from g_retriever_utils.rag_feature_store import \ + SentenceTransformerFeatureStore # noqa: E402 +from g_retriever_utils.rag_graph_store import \ + NeighborSamplingRAGGraphStore # noqa: E402 + +# %% +# Patch FeatureStore and GraphStore + +SentenceTransformerFeatureStore.retrieve_seed_nodes = nvtxit()( + SentenceTransformerFeatureStore.retrieve_seed_nodes) +SentenceTransformerFeatureStore.retrieve_seed_edges = nvtxit()( + SentenceTransformerFeatureStore.retrieve_seed_edges) +SentenceTransformerFeatureStore.load_subgraph = nvtxit()( + SentenceTransformerFeatureStore.load_subgraph) +NeighborSamplingRAGGraphStore.sample_subgraph = nvtxit()( + NeighborSamplingRAGGraphStore.sample_subgraph) +rag_loader.RAGQueryLoader.query = nvtxit()(rag_loader.RAGQueryLoader.query) + +# %% +ds = WebQSPDataset("small_ds_1", force_reload=True, limit=10) + +# %% +triplets = list(chain.from_iterable(d['graph'] for d in ds.raw_dataset)) + +# %% +questions = ds.raw_dataset['question'] + +# %% +ground_truth_graphs = get_features_for_triplets_groups( + ds.indexer, (d['graph'] for d in ds.raw_dataset), + pre_transform=preprocess_triplet) +num_edges = len(ds.indexer._edges) + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SentenceTransformer('sentence-transformers/all-roberta-large-v1').to( + 
+    device)

+# %%
+fs, gs = create_remote_backend_from_triplets(
+    triplets=triplets, node_embedding_model=model,
+    node_method_to_call="encode", path="backend",
+    pre_transform=preprocess_triplet, node_method_kwargs={
+        "batch_size": 256
+    }, graph_db=NeighborSamplingRAGGraphStore,
+    feature_db=SentenceTransformerFeatureStore).load()

+# %%


+@nvtxit()
+def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3,
+                             topk_e: int = 3,
+                             cost_e: float = 0.5) -> Tuple[Data, str]:
+    q_emb = model.encode(query)
+    textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index()
+    textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index()
+    out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes,
+                                         textual_edges, topk, topk_e, cost_e)
+    out_graph["desc"] = desc
+    return out_graph


+# %%
+query_loader = rag_loader.RAGQueryLoader(
+    data=(fs, gs), seed_nodes_kwargs={"k_nodes":
+                                      10}, seed_edges_kwargs={"k_edges": 10},
+    sampler_kwargs={"num_neighbors":
+                    [40] * 10}, local_filter=apply_retrieval_via_pcst)


+# %%
+# Accuracy Metrics to be added to Profiler
+def _eidx_helper(subg: Data, ground_truth: Data):
+    subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx
+    if isinstance(subg_eidx, torch.Tensor):
+        subg_eidx = subg_eidx.tolist()
+    if isinstance(gt_eidx, torch.Tensor):
+        gt_eidx = gt_eidx.tolist()
+    subg_e = set(subg_eidx)
+    gt_e = set(gt_eidx)
+    return subg_e, gt_e


+def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int):
+    subg_e, gt_e = _eidx_helper(subg, ground_truth)
+    total_e = set(range(num_edges))
+    tp = len(subg_e & gt_e)
+    tn = len(total_e - (subg_e | gt_e))
+    return (tp + tn) / num_edges


+def check_retrieval_precision(subg: Data, ground_truth: Data):
+    subg_e, gt_e = _eidx_helper(subg, ground_truth)
+    return len(subg_e & gt_e) / len(subg_e)


+def check_retrieval_recall(subg: Data, ground_truth: Data):
+    subg_e, gt_e = _eidx_helper(subg, ground_truth)
+    return len(subg_e & gt_e) / len(gt_e)


+# %%


+@nvtxit()
+def _run_eval():
+    for subg, gt in zip((query_loader.query(q) for q in questions),
+                        ground_truth_graphs):
+        print(check_retrieval_accuracy(subg, gt, num_edges),
+              check_retrieval_precision(subg, gt),
+              check_retrieval_recall(subg, gt))


+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--capture-torch-kernels", "-k", action="store_true")
+    args = parser.parse_args()
+    if args.capture_torch_kernels:
+        with torch.autograd.profiler.emit_nvtx():
+            _run_eval()
+    else:
+        _run_eval()
diff --git a/examples/llm/nvtx_examples/nvtx_run.sh b/examples/llm/nvtx_examples/nvtx_run.sh
new file mode 100644
index 000000000000..4c6fce7c8224
--- /dev/null
+++ b/examples/llm/nvtx_examples/nvtx_run.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# Check if the user provided a Python file
+if [ -z "$1" ]; then
+    echo "Usage: $0 <python_file>"
+    exit 1
+fi
+
+# Check if the provided file exists
+if [[ ! -f "$1" ]]; then
+    echo "Error: File '$1' does not exist."
+    exit 1
+fi
+
+# Check if the provided file is a Python file
+if [[ ! "$1" == *.py ]]; then
+    echo "Error: '$1' is not a Python file."
+ exit 1 +fi + +# Get the base name of the Python file +python_file=$(basename "$1") + +# Run nsys profile on the Python file +nsys profile -c cudaProfilerApi --capture-range-end repeat -t cuda,nvtx,osrt,cudnn,cublas --cuda-memory-usage true --cudabacktrace all --force-overwrite true --output=profile_${python_file%.py} python "$1" + +echo "Profile data saved as profile_${python_file%.py}.nsys-rep" diff --git a/examples/llm/nvtx_examples/nvtx_webqsp_example.py b/examples/llm/nvtx_examples/nvtx_webqsp_example.py new file mode 100644 index 000000000000..5a9aad27f1c0 --- /dev/null +++ b/examples/llm/nvtx_examples/nvtx_webqsp_example.py @@ -0,0 +1,22 @@ +import argparse + +import torch + +from torch_geometric.datasets import web_qsp_dataset +from torch_geometric.profile import nvtxit + +# Apply Patches +web_qsp_dataset.retrieval_via_pcst = nvtxit()( + web_qsp_dataset.retrieval_via_pcst) +web_qsp_dataset.WebQSPDataset.process = nvtxit()( + web_qsp_dataset.WebQSPDataset.process) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--capture-torch-kernels", "-k", action="store_true") + args = parser.parse_args() + if args.capture_torch_kernels: + with torch.autograd.profiler.emit_nvtx(): + ds = web_qsp_dataset.WebQSPDataset('baseline', split='val') + else: + ds = web_qsp_dataset.WebQSPDataset('baseline', split='val') diff --git a/test/data/test_large_graph_indexer.py b/test/data/test_large_graph_indexer.py new file mode 100644 index 000000000000..b98fe7d7ddbf --- /dev/null +++ b/test/data/test_large_graph_indexer.py @@ -0,0 +1,177 @@ +import random +import string +from typing import List + +import pytest +import torch + +from torch_geometric.data import ( + Data, + LargeGraphIndexer, + TripletLike, + get_features_for_triplets, +) +from torch_geometric.data.large_graph_indexer import ( + EDGE_PID, + EDGE_RELATION, + NODE_PID, +) +from torch_geometric.typing import WITH_PT20 + +# create possible nodes and edges for graph +strkeys = string.ascii_letters + string.digits +NODE_POOL = list({"".join(random.sample(strkeys, 10)) for i in range(1000)}) +EDGE_POOL = list({"".join(random.sample(strkeys, 10)) for i in range(50)}) + + +def featurize(s: str) -> int: + return int.from_bytes(s.encode(), 'little') + + +def sample_triplets(amount: int = 1) -> List[TripletLike]: + trips = [] + for i in range(amount): + h, t = random.sample(NODE_POOL, k=2) + r = random.sample(EDGE_POOL, k=1)[0] + trips.append(tuple([h, r, t])) + return trips + + +def preprocess_triplet(triplet: TripletLike) -> TripletLike: + h, r, t = triplet + return h.lower(), r, t.lower() + + +def test_basic_collate(): + graphs = [sample_triplets(1000) for i in range(2)] + + indexer_0 = LargeGraphIndexer.from_triplets( + graphs[0], pre_transform=preprocess_triplet) + indexer_1 = LargeGraphIndexer.from_triplets( + graphs[1], pre_transform=preprocess_triplet) + + big_indexer = LargeGraphIndexer.collate([indexer_0, indexer_1]) + + assert len(indexer_0._nodes) + len( + indexer_1._nodes) - len(indexer_0._nodes.keys() + & indexer_1._nodes.keys()) == len( + big_indexer._nodes) + assert len(indexer_0._edges) + len( + indexer_1._edges) - len(indexer_0._edges.keys() + & indexer_1._edges.keys()) == len( + big_indexer._edges) + + assert len(set(big_indexer._nodes.values())) == len(big_indexer._nodes) + assert len(set(big_indexer._edges.values())) == len(big_indexer._edges) + + for node in (indexer_0._nodes.keys() | indexer_1._nodes.keys()): + assert big_indexer.node_attr[NODE_PID][ + big_indexer._nodes[node]] == node + + 
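+# A minimal sketch (not itself a test) of the invariant that
+# `test_basic_collate` above exercises: collating two indexers yields one
+# index whose node and edge sets are the set-unions of the inputs. The
+# helper name is hypothetical and only illustrates the API.
+def _collate_usage_sketch() -> None:
+    idx_a = LargeGraphIndexer.from_triplets(sample_triplets(100))
+    idx_b = LargeGraphIndexer.from_triplets(sample_triplets(100))
+    merged = LargeGraphIndexer.collate([idx_a, idx_b])
+    assert set(merged._nodes) == set(idx_a._nodes) | set(idx_b._nodes)
+    assert set(merged._edges) == set(idx_a._edges) | set(idx_b._edges)
+
+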
+def test_large_graph_index(): + graphs = [sample_triplets(1000) for i in range(100)] + + # Preprocessing of trips lowercases nodes but not edges + node_feature_vecs = {s.lower(): featurize(s.lower()) for s in NODE_POOL} + edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL} + + def encode_graph_from_trips(triplets: List[TripletLike]) -> Data: + seen_nodes = dict() + edge_attrs = list() + edge_idx = [] + for trip in triplets: + trip = preprocess_triplet(trip) + h, r, t = trip + seen_nodes[h] = len( + seen_nodes) if h not in seen_nodes else seen_nodes[h] + seen_nodes[t] = len( + seen_nodes) if t not in seen_nodes else seen_nodes[t] + edge_attrs.append(edge_feature_vecs[r]) + edge_idx.append((seen_nodes[h], seen_nodes[t])) + + x = torch.Tensor([node_feature_vecs[n] for n in seen_nodes.keys()]) + edge_idx = torch.LongTensor(edge_idx).T + edge_attrs = torch.Tensor(edge_attrs) + return Data(x=x, edge_index=edge_idx, edge_attr=edge_attrs) + + naive_graph_ds = [ + encode_graph_from_trips(triplets=trips) for trips in graphs + ] + + indexer = LargeGraphIndexer.collate([ + LargeGraphIndexer.from_triplets(g, pre_transform=preprocess_triplet) + for g in graphs + ]) + indexer_nodes = indexer.get_unique_node_features() + indexer_node_vals = torch.Tensor( + [node_feature_vecs[n] for n in indexer_nodes]) + indexer_edges = indexer.get_unique_edge_features( + feature_name=EDGE_RELATION) + indexer_edge_vals = torch.Tensor( + [edge_feature_vecs[e] for e in indexer_edges]) + indexer.add_node_feature('x', indexer_node_vals) + indexer.add_edge_feature('edge_attr', indexer_edge_vals, + map_from_feature=EDGE_RELATION) + large_graph_ds = [ + get_features_for_triplets(indexer=indexer, triplets=g, + node_feature_name='x', + edge_feature_name='edge_attr', + pre_transform=preprocess_triplet) + for g in graphs + ] + + for ds in large_graph_ds: + assert NODE_PID in ds + assert EDGE_PID in ds + assert "node_idx" in ds + assert "edge_idx" in ds + + def results_are_close_enough(ground_truth: Data, new_method: Data, + thresh=.99): + def _sorted_tensors_are_close(tensor1, tensor2): + return torch.all( + torch.isclose(tensor1.sort()[0], + tensor2.sort()[0]) > thresh) + + def _graphs_are_same(tensor1, tensor2): + if not WITH_PT20: + pytest.skip( + "This test requires a PyG version with NetworkX as a " + + "dependency.") + import networkx as nx + return nx.weisfeiler_lehman_graph_hash(nx.Graph( + tensor1.T)) == nx.weisfeiler_lehman_graph_hash( + nx.Graph(tensor2.T)) + return True + return _sorted_tensors_are_close( + ground_truth.x, new_method.x) \ + and _sorted_tensors_are_close( + ground_truth.edge_attr, new_method.edge_attr) \ + and _graphs_are_same( + ground_truth.edge_index, new_method.edge_index) + + for dsets in zip(naive_graph_ds, large_graph_ds): + assert results_are_close_enough(*dsets) + + +def test_save_load(tmp_path): + graph = sample_triplets(1000) + + node_feature_vecs = {s: featurize(s) for s in NODE_POOL} + edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL} + + indexer = LargeGraphIndexer.from_triplets(graph) + indexer_nodes = indexer.get_unique_node_features() + indexer_node_vals = torch.Tensor( + [node_feature_vecs[n] for n in indexer_nodes]) + indexer_edges = indexer.get_unique_edge_features( + feature_name=EDGE_RELATION) + indexer_edge_vals = torch.Tensor( + [edge_feature_vecs[e] for e in indexer_edges]) + indexer.add_node_feature('x', indexer_node_vals) + indexer.add_edge_feature('edge_attr', indexer_edge_vals, + map_from_feature=EDGE_RELATION) + + indexer.save(str(tmp_path)) + assert indexer == 
LargeGraphIndexer.from_disk(str(tmp_path)) diff --git a/test/nn/models/test_g_retriever.py b/test/nn/models/test_g_retriever.py index 899e70730cc9..24a74d1b6f6e 100644 --- a/test/nn/models/test_g_retriever.py +++ b/test/nn/models/test_g_retriever.py @@ -51,3 +51,52 @@ def test_g_retriever() -> None: # Test inference: pred = model.inference(question, x, edge_index, batch, edge_attr) assert len(pred) == 1 + + +@onlyFullTest +@withPackage('transformers', 'sentencepiece', 'accelerate') +def test_g_retriever_many_tokens() -> None: + llm = LLM( + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + dtype=torch.float16, + ) + + gnn = GAT( + in_channels=1024, + out_channels=1024, + hidden_channels=1024, + num_layers=2, + heads=4, + norm='batch_norm', + ) + + model = GRetriever( + llm=llm, + gnn=gnn, + mlp_out_channels=2048, + mlp_out_tokens=2, + ) + assert str(model) == ('GRetriever(\n' + ' llm=LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1),\n' + ' gnn=GAT(1024, 1024, num_layers=2),\n' + ')') + + x = torch.randn(10, 1024) + edge_index = torch.tensor([ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 0], + ]) + edge_attr = torch.randn(edge_index.size(1), 1024) + batch = torch.zeros(x.size(0), dtype=torch.long) + + question = ["Is PyG the best open-source GNN library?"] + label = ["yes!"] + + # Test train: + loss = model(question, x, edge_index, batch, label, edge_attr) + assert loss >= 0 + + # Test inference: + pred = model.inference(question, x, edge_index, batch, edge_attr) + assert len(pred) == 1 diff --git a/test/profile/test_nvtx.py b/test/profile/test_nvtx.py new file mode 100644 index 000000000000..56e28a9c2e59 --- /dev/null +++ b/test/profile/test_nvtx.py @@ -0,0 +1,136 @@ +from unittest.mock import call, patch + +from torch_geometric.profile import nvtxit + + +def _setup_mock(torch_cuda_mock): + torch_cuda_mock.is_available.return_value = True + torch_cuda_mock.cudart.return_value.cudaProfilerStart.return_value = None + torch_cuda_mock.cudart.return_value.cudaProfilerStop.return_value = None + return torch_cuda_mock + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_base(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit() + def call_b(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return 42 + + @nvtxit() + def call_a(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_b() + + def dummy_func(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_a() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + dummy_func() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('call_a_0'), call('call_b_0') + ] + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_rename(torch_cuda_mock): + torch_cuda_mock = 
_setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit() + def call_b(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return 42 + + @nvtxit('a_nvtx') + def call_a(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_b() + + def dummy_func(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_a() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + dummy_func() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('a_nvtx_0'), call('call_b_0') + ] + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_iters(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit(n_iters=1) + def call_b(): + return 42 + + @nvtxit() + def call_a(): + return call_b() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + + call_b() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + call_a() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 2 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 2 # noqa: E501 + + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('call_b_0'), call('call_a_0') + ] + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_warmups(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit(n_warmups=1) + def call_b(): + return 42 + + @nvtxit() + def call_a(): + return call_b() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + + call_b() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + call_a() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('call_a_0'), call('call_b_1') + ] diff --git a/torch_geometric/data/__init__.py b/torch_geometric/data/__init__.py index 821ef9c5c063..fee215b1a357 100644 --- a/torch_geometric/data/__init__.py +++ b/torch_geometric/data/__init__.py @@ -16,6 +16,7 @@ from .makedirs import makedirs from .download import download_url, download_google_url from .extract import extract_tar, extract_zip, extract_bz2, extract_gz +from .large_graph_indexer 
import LargeGraphIndexer, TripletLike, get_features_for_triplets, get_features_for_triplets_groups
 from torch_geometric.lazy_loader import LazyLoader

@@ -27,6 +28,8 @@
     'Dataset',
     'InMemoryDataset',
     'OnDiskDataset',
+    'LargeGraphIndexer',
+    'TripletLike',
 ]

 remote_backend_classes = [
@@ -50,6 +53,8 @@
     'extract_zip',
     'extract_bz2',
     'extract_gz',
+    'get_features_for_triplets',
+    'get_features_for_triplets_groups',
 ]

 __all__ = data_classes + remote_backend_classes + helper_functions
diff --git a/torch_geometric/data/large_graph_indexer.py b/torch_geometric/data/large_graph_indexer.py
new file mode 100644
index 000000000000..0644e2543303
--- /dev/null
+++ b/torch_geometric/data/large_graph_indexer.py
@@ -0,0 +1,677 @@
+import os
+import pickle as pkl
+import shutil
+from dataclasses import dataclass
+from itertools import chain
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Hashable,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
+
+import torch
+from torch import Tensor
+from tqdm import tqdm
+
+from torch_geometric.data import Data
+from torch_geometric.typing import WITH_PT24
+
+TripletLike = Tuple[Hashable, Hashable, Hashable]
+
+KnowledgeGraphLike = Iterable[TripletLike]
+
+
+def ordered_set(values: Iterable[Hashable]) -> List[Hashable]:
+    return list(dict.fromkeys(values))
+
+
+# TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum?
+
+NODE_PID = "pid"
+
+NODE_KEYS = {NODE_PID}
+
+EDGE_PID = "e_pid"
+EDGE_HEAD = "h"
+EDGE_RELATION = "r"
+EDGE_TAIL = "t"
+EDGE_INDEX = "edge_idx"
+
+EDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX}
+
+FeatureValueType = Union[Sequence[Any], Tensor]
+
+
+@dataclass
+class MappedFeature:
+    name: str
+    values: FeatureValueType
+
+    def __eq__(self, value: "MappedFeature") -> bool:
+        eq = self.name == value.name
+        if isinstance(self.values, torch.Tensor):
+            eq &= torch.equal(self.values, value.values)
+        else:
+            eq &= self.values == value.values
+        return eq
+
+
+if WITH_PT24:
+    torch.serialization.add_safe_globals([MappedFeature])
+
+
+class LargeGraphIndexer:
+    """For a dataset that consists of multiple subgraphs that are assumed to
+    be part of a much larger graph, collate the values into a large graph
+    store to save resources.
+    """
+    def __init__(
+        self,
+        nodes: Iterable[Hashable],
+        edges: KnowledgeGraphLike,
+        node_attr: Optional[Dict[str, List[Any]]] = None,
+        edge_attr: Optional[Dict[str, List[Any]]] = None,
+    ) -> None:
+        r"""Constructs a new index that uniquely catalogs each node and edge
+        by id. Not meant to be used directly.
+
+        Args:
+            nodes (Iterable[Hashable]): Node ids in the graph.
+            edges (KnowledgeGraphLike): Edge ids in the graph.
+            node_attr (Optional[Dict[str, List[Any]]], optional): Mapping of
+                node attribute names to lists of their values in order of
+                unique node ids. Defaults to None.
+            edge_attr (Optional[Dict[str, List[Any]]], optional): Mapping of
+                edge attribute names to lists of their values in order of
+                unique edge ids. Defaults to None.
+ """ + self._nodes: Dict[Hashable, int] = dict() + self._edges: Dict[TripletLike, int] = dict() + + self._mapped_node_features: Set[str] = set() + self._mapped_edge_features: Set[str] = set() + + if len(nodes) != len(set(nodes)): + raise AttributeError("Nodes need to be unique") + if len(edges) != len(set(edges)): + raise AttributeError("Edges need to be unique") + + if node_attr is not None: + # TODO: Validity checks btw nodes and node_attr + self.node_attr = node_attr + if NODE_KEYS & set(self.node_attr.keys()) != NODE_KEYS: + raise AttributeError( + "Invalid node_attr object. Missing " + + f"{NODE_KEYS - set(self.node_attr.keys())}") + elif self.node_attr[NODE_PID] != nodes: + raise AttributeError( + "Nodes provided do not match those in node_attr") + else: + self.node_attr = dict() + self.node_attr[NODE_PID] = nodes + + for i, node in enumerate(self.node_attr[NODE_PID]): + self._nodes[node] = i + + if edge_attr is not None: + # TODO: Validity checks btw edges and edge_attr + self.edge_attr = edge_attr + + if EDGE_KEYS & set(self.edge_attr.keys()) != EDGE_KEYS: + raise AttributeError( + "Invalid edge_attr object. Missing " + + f"{EDGE_KEYS - set(self.edge_attr.keys())}") + elif self.node_attr[EDGE_PID] != edges: + raise AttributeError( + "Edges provided do not match those in edge_attr") + + else: + self.edge_attr = dict() + for default_key in EDGE_KEYS: + self.edge_attr[default_key] = list() + self.edge_attr[EDGE_PID] = edges + + for i, tup in enumerate(edges): + h, r, t = tup + self.edge_attr[EDGE_HEAD].append(h) + self.edge_attr[EDGE_RELATION].append(r) + self.edge_attr[EDGE_TAIL].append(t) + self.edge_attr[EDGE_INDEX].append( + (self._nodes[h], self._nodes[t])) + + for i, tup in enumerate(edges): + self._edges[tup] = i + + @classmethod + def from_triplets( + cls, + triplets: KnowledgeGraphLike, + pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None, + ) -> "LargeGraphIndexer": + r"""Generate a new index from a series of triplets that represent edge + relations between nodes. + Formatted like (source_node, edge, dest_node). + + Args: + triplets (KnowledgeGraphLike): Series of triplets representing + knowledge graph relations. + pre_transform (Optional[Callable[[TripletLike], TripletLike]]): + Optional preprocessing function to apply to triplets. + Defaults to None. + + Returns: + LargeGraphIndexer: Index of unique nodes and edges. + """ + # NOTE: Right now assumes that all trips can be loaded into memory + nodes = set() + edges = set() + + if pre_transform is not None: + + def apply_transform( + trips: KnowledgeGraphLike) -> Iterator[TripletLike]: + for trip in trips: + yield pre_transform(trip) + + triplets = apply_transform(triplets) + + for h, r, t in triplets: + + for node in (h, t): + nodes.add(node) + + edge_idx = (h, r, t) + edges.add(edge_idx) + + return cls(list(nodes), list(edges)) + + @classmethod + def collate(cls, + graphs: Iterable["LargeGraphIndexer"]) -> "LargeGraphIndexer": + r"""Combines a series of large graph indexes into a single large graph + index. + + Args: + graphs (Iterable["LargeGraphIndexer"]): Indices to be + combined. + + Returns: + LargeGraphIndexer: Singular unique index for all nodes and edges + in input indices. + """ + # FIXME Needs to merge node attrs and edge attrs? + trips = chain.from_iterable([graph.to_triplets() for graph in graphs]) + return cls.from_triplets(trips) + + def get_unique_node_features( + self, feature_name: str = NODE_PID) -> List[Hashable]: + r"""Get all the unique values for a specific node attribute. 
+
+        Args:
+            feature_name (str, optional): Name of feature to get.
+                Defaults to NODE_PID.
+
+        Returns:
+            List[Hashable]: List of unique values for the specified feature.
+        """
+        try:
+            if feature_name in self._mapped_node_features:
+                raise IndexError(
+                    "Only non-mapped features can be retrieved uniquely.")
+            return ordered_set(self.get_node_features(feature_name))
+
+        except KeyError:
+            raise AttributeError(
+                f"Nodes do not have a feature called {feature_name}")
+
+    def add_node_feature(
+        self,
+        new_feature_name: str,
+        new_feature_vals: FeatureValueType,
+        map_from_feature: str = NODE_PID,
+    ) -> None:
+        r"""Adds a new feature that corresponds to each unique node in
+        the graph.
+
+        Args:
+            new_feature_name (str): Name to call the new feature.
+            new_feature_vals (FeatureValueType): Values to map for that
+                new feature.
+            map_from_feature (str, optional): Key of feature to map from.
+                Size must match the number of feature values.
+                Defaults to NODE_PID.
+        """
+        if new_feature_name in self.node_attr:
+            raise AttributeError("Features cannot be overridden once created")
+        if map_from_feature in self._mapped_node_features:
+            raise AttributeError(
+                f"{map_from_feature} is already a feature mapping.")
+
+        feature_keys = self.get_unique_node_features(map_from_feature)
+        if len(feature_keys) != len(new_feature_vals):
+            raise AttributeError(
+                f"Expected encodings for {len(feature_keys)} unique features," +
+                f" but got {len(new_feature_vals)} encodings.")
+
+        if map_from_feature == NODE_PID:
+            self.node_attr[new_feature_name] = new_feature_vals
+        else:
+            self.node_attr[new_feature_name] = MappedFeature(
+                name=map_from_feature, values=new_feature_vals)
+            self._mapped_node_features.add(new_feature_name)
+
+    def get_node_features(
+        self,
+        feature_name: str = NODE_PID,
+        pids: Optional[Iterable[Hashable]] = None,
+    ) -> List[Any]:
+        r"""Get node feature values for a given set of unique node ids.
+        Returned values are not necessarily unique.
+
+        Args:
+            feature_name (str, optional): Name of feature to fetch. Defaults
+                to NODE_PID.
+            pids (Optional[Iterable[Hashable]], optional): Node ids to fetch
+                for. Defaults to None, which fetches all nodes.
+
+        Returns:
+            List[Any]: Node features corresponding to the specified ids.
+        """
+        if feature_name in self._mapped_node_features:
+            values = self.node_attr[feature_name].values
+        else:
+            values = self.node_attr[feature_name]
+
+        # TODO: torch_geometric.utils.select
+        if isinstance(values, torch.Tensor):
+            idxs = list(
+                self.get_node_features_iter(feature_name, pids,
+                                            index_only=True))
+            return values[idxs]
+        return list(self.get_node_features_iter(feature_name, pids))
+
+    def get_node_features_iter(
+        self,
+        feature_name: str = NODE_PID,
+        pids: Optional[Iterable[Hashable]] = None,
+        index_only: bool = False,
+    ) -> Iterator[Any]:
+        """Iterator version of get_node_features. If index_only is True,
+        yields indices instead of values.
+ """ + if pids is None: + pids = self.node_attr[NODE_PID] + + if feature_name in self._mapped_node_features: + feature_map_info = self.node_attr[feature_name] + from_feature_name, to_feature_vals = ( + feature_map_info.name, + feature_map_info.values, + ) + from_feature_vals = self.get_unique_node_features( + from_feature_name) + feature_mapping = {k: i for i, k in enumerate(from_feature_vals)} + + for pid in pids: + idx = self._nodes[pid] + from_feature_val = self.node_attr[from_feature_name][idx] + to_feature_idx = feature_mapping[from_feature_val] + if index_only: + yield to_feature_idx + else: + yield to_feature_vals[to_feature_idx] + else: + for pid in pids: + idx = self._nodes[pid] + if index_only: + yield idx + else: + yield self.node_attr[feature_name][idx] + + def get_unique_edge_features( + self, feature_name: str = EDGE_PID) -> List[Hashable]: + r"""Get all the unique values for a specific edge attribute. + + Args: + feature_name (str, optional): Name of feature to get. + Defaults to EDGE_PID. + + Returns: + List[Hashable]: List of unique values for the specified feature. + """ + try: + if feature_name in self._mapped_edge_features: + raise IndexError( + "Only non-mapped features can be retrieved uniquely.") + return ordered_set(self.get_edge_features(feature_name)) + except KeyError: + raise AttributeError( + f"Edges do not have a feature called {feature_name}") + + def add_edge_feature( + self, + new_feature_name: str, + new_feature_vals: FeatureValueType, + map_from_feature: str = EDGE_PID, + ) -> None: + r"""Adds a new feature that corresponds to each unique edge in + the graph. + + Args: + new_feature_name (str): Name to call the new feature. + new_feature_vals (FeatureValueType): Values to map for that new + feature. + map_from_feature (str, optional): Key of feature to map from. + Size must match the number of feature values. + Defaults to EDGE_PID. + """ + if new_feature_name in self.edge_attr: + raise AttributeError("Features cannot be overridden once created") + if map_from_feature in self._mapped_edge_features: + raise AttributeError( + f"{map_from_feature} is already a feature mapping.") + + feature_keys = self.get_unique_edge_features(map_from_feature) + if len(feature_keys) != len(new_feature_vals): + raise AttributeError( + f"Expected encodings for {len(feature_keys)} unique features, " + + f"but got {len(new_feature_vals)} encodings.") + + if map_from_feature == EDGE_PID: + self.edge_attr[new_feature_name] = new_feature_vals + else: + self.edge_attr[new_feature_name] = MappedFeature( + name=map_from_feature, values=new_feature_vals) + self._mapped_edge_features.add(new_feature_name) + + def get_edge_features( + self, + feature_name: str = EDGE_PID, + pids: Optional[Iterable[Hashable]] = None, + ) -> List[Any]: + r"""Get edge feature values for a given set of unique edge ids. + Returned values are not necessarily unique. + + Args: + feature_name (str, optional): Name of feature to fetch. + Defaults to EDGE_PID. + pids (Optional[Iterable[Hashable]], optional): Edge ids to fetch + for. Defaults to None, which fetches all edges. + + Returns: + List[Any]: Node features corresponding to the specified ids. 
+ """ + if feature_name in self._mapped_edge_features: + values = self.edge_attr[feature_name].values + else: + values = self.edge_attr[feature_name] + + # TODO: torch_geometric.utils.select + if isinstance(values, torch.Tensor): + idxs = list( + self.get_edge_features_iter(feature_name, pids, + index_only=True)) + return values[idxs] + return list(self.get_edge_features_iter(feature_name, pids)) + + def get_edge_features_iter( + self, + feature_name: str = EDGE_PID, + pids: Optional[KnowledgeGraphLike] = None, + index_only: bool = False, + ) -> Iterator[Any]: + """Iterator version of get_edge_features. If index_only is True, + yields indices instead of values. + """ + if pids is None: + pids = self.edge_attr[EDGE_PID] + + if feature_name in self._mapped_edge_features: + feature_map_info = self.edge_attr[feature_name] + from_feature_name, to_feature_vals = ( + feature_map_info.name, + feature_map_info.values, + ) + from_feature_vals = self.get_unique_edge_features( + from_feature_name) + feature_mapping = {k: i for i, k in enumerate(from_feature_vals)} + + for pid in pids: + idx = self._edges[pid] + from_feature_val = self.edge_attr[from_feature_name][idx] + to_feature_idx = feature_mapping[from_feature_val] + if index_only: + yield to_feature_idx + else: + yield to_feature_vals[to_feature_idx] + else: + for pid in pids: + idx = self._edges[pid] + if index_only: + yield idx + else: + yield self.edge_attr[feature_name][idx] + + def to_triplets(self) -> Iterator[TripletLike]: + return iter(self.edge_attr[EDGE_PID]) + + def save(self, path: str) -> None: + if os.path.exists(path): + shutil.rmtree(path) + os.makedirs(path, exist_ok=True) + with open(path + "/edges", "wb") as f: + pkl.dump(self._edges, f) + with open(path + "/nodes", "wb") as f: + pkl.dump(self._nodes, f) + + with open(path + "/mapped_edges", "wb") as f: + pkl.dump(self._mapped_edge_features, f) + with open(path + "/mapped_nodes", "wb") as f: + pkl.dump(self._mapped_node_features, f) + + node_attr_path = path + "/node_attr" + os.makedirs(node_attr_path, exist_ok=True) + for attr_name, vals in self.node_attr.items(): + torch.save(vals, node_attr_path + f"/{attr_name}.pt") + + edge_attr_path = path + "/edge_attr" + os.makedirs(edge_attr_path, exist_ok=True) + for attr_name, vals in self.edge_attr.items(): + torch.save(vals, edge_attr_path + f"/{attr_name}.pt") + + @classmethod + def from_disk(cls, path: str) -> "LargeGraphIndexer": + indexer = cls(list(), list()) + with open(path + "/edges", "rb") as f: + indexer._edges = pkl.load(f) + with open(path + "/nodes", "rb") as f: + indexer._nodes = pkl.load(f) + + with open(path + "/mapped_edges", "rb") as f: + indexer._mapped_edge_features = pkl.load(f) + with open(path + "/mapped_nodes", "rb") as f: + indexer._mapped_node_features = pkl.load(f) + + node_attr_path = path + "/node_attr" + for fname in os.listdir(node_attr_path): + full_fname = f"{node_attr_path}/{fname}" + key = fname.split(".")[0] + indexer.node_attr[key] = torch.load(full_fname) + + edge_attr_path = path + "/edge_attr" + for fname in os.listdir(edge_attr_path): + full_fname = f"{edge_attr_path}/{fname}" + key = fname.split(".")[0] + indexer.edge_attr[key] = torch.load(full_fname) + + return indexer + + def to_data(self, node_feature_name: str, + edge_feature_name: Optional[str] = None) -> Data: + """Return a Data object containing all the specified node and + edge features and the graph. 
+ + Args: + node_feature_name (str): Feature to use for nodes + edge_feature_name (Optional[str], optional): Feature to use for + edges. Defaults to None. + + Returns: + Data: Data object containing the specified node and + edge features and the graph. + """ + x = torch.Tensor(self.get_node_features(node_feature_name)) + node_id = torch.LongTensor(range(len(x))) + + edge_index = torch.t( + torch.LongTensor(self.get_edge_features(EDGE_INDEX))) + + edge_attr = (self.get_edge_features(edge_feature_name) + if edge_feature_name is not None else None) + edge_id = torch.LongTensor(range(len(edge_attr))) + + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, + edge_id=edge_id, node_id=node_id) + + def __eq__(self, value: "LargeGraphIndexer") -> bool: + eq = True + eq &= self._nodes == value._nodes + eq &= self._edges == value._edges + eq &= self.node_attr.keys() == value.node_attr.keys() + eq &= self.edge_attr.keys() == value.edge_attr.keys() + eq &= self._mapped_node_features == value._mapped_node_features + eq &= self._mapped_edge_features == value._mapped_edge_features + + for k in self.node_attr: + eq &= isinstance(self.node_attr[k], type(value.node_attr[k])) + if isinstance(self.node_attr[k], torch.Tensor): + eq &= torch.equal(self.node_attr[k], value.node_attr[k]) + else: + eq &= self.node_attr[k] == value.node_attr[k] + for k in self.edge_attr: + eq &= isinstance(self.edge_attr[k], type(value.edge_attr[k])) + if isinstance(self.edge_attr[k], torch.Tensor): + eq &= torch.equal(self.edge_attr[k], value.edge_attr[k]) + else: + eq &= self.edge_attr[k] == value.edge_attr[k] + return eq + + +def get_features_for_triplets_groups( + indexer: LargeGraphIndexer, + triplet_groups: Iterable[KnowledgeGraphLike], + node_feature_name: str = "x", + edge_feature_name: str = "edge_attr", + pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None, + verbose: bool = False, +) -> Iterator[Data]: + """Given an indexer and a series of triplet groups (like a dataset), + retrieve the specified node and edge features for each triplet from the + index. + + Args: + indexer (LargeGraphIndexer): Indexer containing desired features + triplet_groups (Iterable[KnowledgeGraphLike]): List of lists of + triplets to fetch features for + node_feature_name (str, optional): Node feature to fetch. + Defaults to "x". + edge_feature_name (str, optional): edge feature to fetch. + Defaults to "edge_attr". + pre_transform (Optional[Callable[[TripletLike], TripletLike]]): + Optional preprocessing to perform on triplets. + Defaults to None. + verbose (bool, optional): Whether to print progress. Defaults to False. + + Yields: + Iterator[Data]: For each triplet group, yield a data object containing + the unique graph and features from the index. + """ + if pre_transform is not None: + + def apply_transform(trips): + for trip in trips: + yield pre_transform(tuple(trip)) + + # TODO: Make this safe for large amounts of triplets? 
+        triplet_groups = (list(apply_transform(triplets))
+                          for triplets in triplet_groups)
+
+    node_keys = []
+    edge_keys = []
+    edge_index = []
+
+    for triplets in tqdm(triplet_groups, disable=not verbose):
+        small_graph_indexer = LargeGraphIndexer.from_triplets(
+            triplets, pre_transform=pre_transform)
+
+        node_keys.append(small_graph_indexer.get_node_features())
+        edge_keys.append(small_graph_indexer.get_edge_features(pids=triplets))
+        edge_index.append(
+            small_graph_indexer.get_edge_features(EDGE_INDEX, triplets))
+
+    node_feats = indexer.get_node_features(feature_name=node_feature_name,
+                                           pids=chain.from_iterable(node_keys))
+    edge_feats = indexer.get_edge_features(feature_name=edge_feature_name,
+                                           pids=chain.from_iterable(edge_keys))
+
+    last_node_idx, last_edge_idx = 0, 0
+    for (nkeys, ekeys, eidx) in zip(node_keys, edge_keys, edge_index):
+        nlen, elen = len(nkeys), len(ekeys)
+        x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen])
+        last_node_idx += len(nkeys)
+
+        edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx +
+                                            elen])
+        last_edge_idx += len(ekeys)
+
+        edge_idx = torch.LongTensor(eidx).T
+
+        data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)
+        data_obj[NODE_PID] = nkeys
+        data_obj[EDGE_PID] = ekeys
+        data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys]
+        data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys]
+
+        yield data_obj
+
+
+def get_features_for_triplets(
+    indexer: LargeGraphIndexer,
+    triplets: KnowledgeGraphLike,
+    node_feature_name: str = "x",
+    edge_feature_name: str = "edge_attr",
+    pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
+    verbose: bool = False,
+) -> Data:
+    """For a given set of triplets retrieve a Data object containing the
+    unique graph and features from the index.
+
+    Args:
+        indexer (LargeGraphIndexer): Indexer containing desired features
+        triplets (KnowledgeGraphLike): Triplets to fetch features for
+        node_feature_name (str, optional): Feature to use for node features.
+            Defaults to "x".
+        edge_feature_name (str, optional): Feature to use for edge features.
+            Defaults to "edge_attr".
+        pre_transform (Optional[Callable[[TripletLike], TripletLike]]):
+            Optional preprocessing function for triplets. Defaults to None.
+        verbose (bool, optional): Whether to print progress. Defaults to False.
+
+    Returns:
+        Data: Data object containing the unique graph and features from the
+        index for the given triplets.
+ """ + gen = get_features_for_triplets_groups(indexer, [triplets], + node_feature_name, + edge_feature_name, pre_transform, + verbose) + return next(gen) diff --git a/torch_geometric/loader/__init__.py b/torch_geometric/loader/__init__.py index 266f498a113b..7e83c35befb6 100644 --- a/torch_geometric/loader/__init__.py +++ b/torch_geometric/loader/__init__.py @@ -22,6 +22,7 @@ from .prefetch import PrefetchLoader from .cache import CachedLoader from .mixin import AffinityMixin +from .rag_loader import RAGQueryLoader __all__ = classes = [ 'DataLoader', @@ -50,6 +51,7 @@ 'PrefetchLoader', 'CachedLoader', 'AffinityMixin', + 'RAGQueryLoader', ] RandomNodeSampler = deprecated( diff --git a/torch_geometric/loader/rag_loader.py b/torch_geometric/loader/rag_loader.py new file mode 100644 index 000000000000..33d6cf0e868e --- /dev/null +++ b/torch_geometric/loader/rag_loader.py @@ -0,0 +1,106 @@ +from abc import abstractmethod +from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union + +from torch_geometric.data import Data, FeatureStore, HeteroData +from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput +from torch_geometric.typing import InputEdges, InputNodes + + +class RAGFeatureStore(Protocol): + """Feature store for remote GNN RAG backend.""" + @abstractmethod + def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes: + """Makes a comparison between the query and all the nodes to get all + the closest nodes. Return the indices of the nodes that are to be seeds + for the RAG Sampler. + """ + ... + + @abstractmethod + def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges: + """Makes a comparison between the query and all the edges to get all + the closest nodes. Returns the edge indices that are to be the seeds + for the RAG Sampler. + """ + ... + + @abstractmethod + def load_subgraph( + self, sample: Union[SamplerOutput, HeteroSamplerOutput] + ) -> Union[Data, HeteroData]: + """Combines sampled subgraph output with features in a Data object.""" + ... + + +class RAGGraphStore(Protocol): + """Graph store for remote GNN RAG backend.""" + @abstractmethod + def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges, + **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]: + """Sample a subgraph using the seeded nodes and edges.""" + ... + + @abstractmethod + def register_feature_store(self, feature_store: FeatureStore): + """Register a feature store to be used with the sampler. Samplers need + info from the feature store in order to work properly on HeteroGraphs. + """ + ... + + +# TODO: Make compatible with Heterographs + + +class RAGQueryLoader: + def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore], + local_filter: Optional[Callable[[Data, Any], Data]] = None, + seed_nodes_kwargs: Optional[Dict[str, Any]] = None, + seed_edges_kwargs: Optional[Dict[str, Any]] = None, + sampler_kwargs: Optional[Dict[str, Any]] = None, + loader_kwargs: Optional[Dict[str, Any]] = None): + """Loader meant for making queries from a remote backend. + + Args: + data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore + and GraphStore to load from. Assumed to conform to the + protocols listed above. + local_filter (Optional[Callable[[Data, Any], Data]], optional): + Optional local transform to apply to data after retrieval. + Defaults to None. + seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Paramaters + to pass into process for fetching seed nodes. Defaults to None. 
+ seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters + to pass into process for fetching seed edges. Defaults to None. + sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to + pass into process for sampling graph. Defaults to None. + loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to + pass into process for loading graph features. Defaults to None. + """ + fstore, gstore = data + self.feature_store = fstore + self.graph_store = gstore + self.graph_store.register_feature_store(self.feature_store) + self.local_filter = local_filter + self.seed_nodes_kwargs = seed_nodes_kwargs or {} + self.seed_edges_kwargs = seed_edges_kwargs or {} + self.sampler_kwargs = sampler_kwargs or {} + self.loader_kwargs = loader_kwargs or {} + + def query(self, query: Any) -> Data: + """Retrieve a subgraph associated with the query with all its feature + attributes. + """ + seed_nodes = self.feature_store.retrieve_seed_nodes( + query, **self.seed_nodes_kwargs) + seed_edges = self.feature_store.retrieve_seed_edges( + query, **self.seed_edges_kwargs) + + subgraph_sample = self.graph_store.sample_subgraph( + seed_nodes, seed_edges, **self.sampler_kwargs) + + data = self.feature_store.load_subgraph(sample=subgraph_sample, + **self.loader_kwargs) + + if self.local_filter: + data = self.local_filter(data, query) + return data diff --git a/torch_geometric/nn/models/g_retriever.py b/torch_geometric/nn/models/g_retriever.py index 6f8fbcc644dc..f7529ae721b7 100644 --- a/torch_geometric/nn/models/g_retriever.py +++ b/torch_geometric/nn/models/g_retriever.py @@ -21,6 +21,8 @@ class GRetriever(torch.nn.Module): (default: :obj:`False`) mlp_out_channels (int, optional): The size of each graph embedding after projection. (default: :obj:`4096`) + mlp_out_tokens (int, optional): Number of LLM prefix tokens to + reserve for GNN output. (default: :obj:`1`) .. 
warning:: This module has been tested with the following HuggingFace models @@ -43,6 +45,7 @@ def __init__( gnn: torch.nn.Module, use_lora: bool = False, mlp_out_channels: int = 4096, + mlp_out_tokens: int = 1, ) -> None: super().__init__() @@ -77,7 +80,9 @@ def __init__( self.projector = torch.nn.Sequential( torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels), torch.nn.Sigmoid(), - torch.nn.Linear(mlp_hidden_channels, mlp_out_channels), + torch.nn.Linear(mlp_hidden_channels, + mlp_out_channels * mlp_out_tokens), + torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)), ).to(self.llm.device) def encode( @@ -126,6 +131,9 @@ def forward( x = self.projector(x) xs = x.split(1, dim=0) + # Handle case where theres more than one embedding for each sample + xs = [x.squeeze(0) for x in xs] + # Handle questions without node features: batch_unique = batch.unique() batch_size = len(question) @@ -182,6 +190,9 @@ def inference( x = self.projector(x) xs = x.split(1, dim=0) + # Handle case where theres more than one embedding for each sample + xs = [x.squeeze(0) for x in xs] + # Handle questions without node features: batch_unique = batch.unique() batch_size = len(question) diff --git a/torch_geometric/profile/__init__.py b/torch_geometric/profile/__init__.py index 833ee657d0e7..22d3039f4c83 100644 --- a/torch_geometric/profile/__init__.py +++ b/torch_geometric/profile/__init__.py @@ -20,6 +20,7 @@ get_gpu_memory_from_nvidia_smi, get_model_size, ) +from .nvtx import nvtxit __all__ = [ 'profileit', @@ -38,6 +39,7 @@ 'get_gpu_memory_from_nvidia_smi', 'get_gpu_memory_from_ipex', 'benchmark', + 'nvtxit', ] classes = __all__ diff --git a/torch_geometric/profile/nvtx.py b/torch_geometric/profile/nvtx.py new file mode 100644 index 000000000000..8dbce375ae5a --- /dev/null +++ b/torch_geometric/profile/nvtx.py @@ -0,0 +1,66 @@ +from functools import wraps +from typing import Optional + +import torch + +CUDA_PROFILE_STARTED = False + + +def begin_cuda_profile(): + global CUDA_PROFILE_STARTED + prev_state = CUDA_PROFILE_STARTED + if prev_state is False: + CUDA_PROFILE_STARTED = True + torch.cuda.cudart().cudaProfilerStart() + return prev_state + + +def end_cuda_profile(prev_state: bool): + global CUDA_PROFILE_STARTED + CUDA_PROFILE_STARTED = prev_state + if prev_state is False: + torch.cuda.cudart().cudaProfilerStop() + + +def nvtxit(name: Optional[str] = None, n_warmups: int = 0, + n_iters: Optional[int] = None): + """Enables NVTX profiling for a function. + + Args: + name (Optional[str], optional): Name to give the reference frame for + the function being wrapped. Defaults to the name of the + function in code. + n_warmups (int, optional): Number of iters to call that function + before starting. Defaults to 0. + n_iters (Optional[int], optional): Number of iters of that function to + record. Defaults to all of them. + """ + def nvtx(func): + + nonlocal name + iters_so_far = 0 + if name is None: + name = func.__name__ + + @wraps(func) + def wrapper(*args, **kwargs): + nonlocal iters_so_far + if not torch.cuda.is_available(): + return func(*args, **kwargs) + elif iters_so_far < n_warmups: + iters_so_far += 1 + return func(*args, **kwargs) + elif n_iters is None or iters_so_far < n_iters + n_warmups: + prev_state = begin_cuda_profile() + torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}") + result = func(*args, **kwargs) + torch.cuda.nvtx.range_pop() + end_cuda_profile(prev_state) + iters_so_far += 1 + return result + else: + return func(*args, **kwargs) + + return wrapper + + return nvtx
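
For reference, a minimal usage sketch of `nvtxit` (the `forward_pass` function and the `my_script.py` file name below are hypothetical): with `n_warmups=1` and `n_iters=10`, the decorator runs the first call unprofiled, records the next ten calls as NVTX ranges, and passes any further calls through untouched.

```python
# Minimal sketch of profiling a function with nvtxit; assumes a CUDA device
# is available (without one, the wrapper simply calls the function directly).
import torch

from torch_geometric.profile import nvtxit


@nvtxit(name="forward_pass", n_warmups=1, n_iters=10)
def forward_pass(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical workload; each recorded call appears as an NVTX range
    # named "forward_pass_<call index>" in the Nsight Systems timeline.
    return x @ x.T


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(128, 128, device=device)
    for _ in range(12):  # 1 warmup + 10 recorded + 1 unrecorded call
        forward_pass(x)
```

Saved as `my_script.py`, this would be profiled with `./nvtx_run.sh my_script.py`, which wraps `nsys profile` with `-c cudaProfilerApi` so that capture starts and stops at the `cudaProfilerStart`/`cudaProfilerStop` calls issued by the decorator, and writes the report to `profile_my_script.nsys-rep`.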