From 89985d8a2162acf985c0b2f09f629d9516480f57 Mon Sep 17 00:00:00 2001 From: Zack Aristei Date: Thu, 15 Aug 2024 16:51:09 -0700 Subject: [PATCH] forgot a few more test scripts --- .../llm_plus_gnn/test_nvtx_rag_backend.py | 101 ++++++++++++++++++ examples/llm_plus_gnn/test_nvtx_uwebqsp.py | 27 +++++ examples/llm_plus_gnn/test_nvtx_webqsp.py | 19 ++++ 3 files changed, 147 insertions(+) create mode 100644 examples/llm_plus_gnn/test_nvtx_rag_backend.py create mode 100644 examples/llm_plus_gnn/test_nvtx_uwebqsp.py create mode 100644 examples/llm_plus_gnn/test_nvtx_webqsp.py diff --git a/examples/llm_plus_gnn/test_nvtx_rag_backend.py b/examples/llm_plus_gnn/test_nvtx_rag_backend.py new file mode 100644 index 0000000000000..b0f89bfcca619 --- /dev/null +++ b/examples/llm_plus_gnn/test_nvtx_rag_backend.py @@ -0,0 +1,101 @@ +# %% +from profiling_utils import create_remote_backend_from_triplets +from rag_feature_store import SentenceTransformerFeatureStore +from rag_graph_store import NeighborSamplingRAGGraphStore +from torch_geometric.loader import rag_loader +from torch_geometric.datasets import UpdatedWebQSPDataset +from torch_geometric.nn.nlp import SentenceTransformer +from torch_geometric.datasets.updated_web_qsp_dataset import preprocess_triplet +from torch_geometric.data import get_features_for_triplets_groups, Data +from itertools import chain +from torch_geometric.profile.nvtx import nvtxit +import torch +import argparse +from typing import Tuple + +# %% +# Patch FeatureStore and GraphStore + +SentenceTransformerFeatureStore.retrieve_seed_nodes = nvtxit()(SentenceTransformerFeatureStore.retrieve_seed_nodes) +SentenceTransformerFeatureStore.retrieve_seed_edges = nvtxit()(SentenceTransformerFeatureStore.retrieve_seed_edges) +SentenceTransformerFeatureStore.load_subgraph = nvtxit()(SentenceTransformerFeatureStore.load_subgraph) +NeighborSamplingRAGGraphStore.sample_subgraph = nvtxit()(NeighborSamplingRAGGraphStore.sample_subgraph) +rag_loader.RAGQueryLoader.query = nvtxit()(rag_loader.RAGQueryLoader.query) + +# %% +ds = UpdatedWebQSPDataset("small_ds_1", force_reload=True, limit=10) + +# %% +triplets = list(chain.from_iterable((d['graph'] for d in ds.raw_dataset))) + +# %% +questions = ds.raw_dataset['question'] + +# %% +ground_truth_graphs = get_features_for_triplets_groups(ds.indexer, (d['graph'] for d in ds.raw_dataset), pre_transform=preprocess_triplet) +num_edges = len(ds.indexer._edges) + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SentenceTransformer().to(device) + +# %% +fs, gs = create_remote_backend_from_triplets(triplets=triplets, node_embedding_model=model, node_method_to_call="encode", path="backend", pre_transform=preprocess_triplet, node_method_kwargs={"batch_size": 256}, graph_db=NeighborSamplingRAGGraphStore, feature_db=SentenceTransformerFeatureStore).load() + +# %% +from torch_geometric.datasets.updated_web_qsp_dataset import retrieval_via_pcst + +@nvtxit() +def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3, topk_e: int = 3, cost_e: float = 0.5) -> Tuple[Data, str]: + q_emb = model.encode(query) + textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index() + out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, textual_edges, topk, topk_e, cost_e) + out_graph["desc"] = desc + return graph + +# %% +query_loader = rag_loader.RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 10}, seed_edges_kwargs={"k_edges": 10}, sampler_kwargs={"num_neighbors": [40]*10}, local_filter=apply_retrieval_via_pcst) + +# %% +# Accuracy Metrics to be added to Profiler +def _eidx_helper(subg: Data, ground_truth: Data): + subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx + if isinstance(subg_eidx, torch.Tensor): + subg_eidx = subg_eidx.tolist() + if isinstance(gt_eidx, torch.Tensor): + gt_eidx = gt_eidx.tolist() + subg_e = set(subg_eidx) + gt_e = set(gt_eidx) + return subg_e, gt_e +def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + total_e = set(range(num_edges)) + tp = len(subg_e & gt_e) + tn = len(total_e-(subg_e | gt_e)) + return (tp+tn)/num_edges +def check_retrieval_precision(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(subg_e) +def check_retrieval_recall(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(gt_e) + + +# %% + +@nvtxit() +def _run_eval(): + for subg, gt in zip((query_loader.query(q) for q in questions), ground_truth_graphs): + print(check_retrieval_accuracy(subg, gt, num_edges), check_retrieval_precision(subg, gt), check_retrieval_recall(subg, gt)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--capture-torch-kernels", "-k", action="store_true") + args = parser.parse_args() + if args.capture_torch_kernels: + with torch.autograd.profiler.emit_nvtx(): + _run_eval() + else: + _run_eval() \ No newline at end of file diff --git a/examples/llm_plus_gnn/test_nvtx_uwebqsp.py b/examples/llm_plus_gnn/test_nvtx_uwebqsp.py new file mode 100644 index 0000000000000..6204fe603b84e --- /dev/null +++ b/examples/llm_plus_gnn/test_nvtx_uwebqsp.py @@ -0,0 +1,27 @@ +from torch_geometric.datasets import updated_web_qsp_dataset +from torch_geometric.profile import nvtxit +import torch +import argparse + +# Apply Patches +updated_web_qsp_dataset.UpdatedWebQSPDataset.process = nvtxit()(updated_web_qsp_dataset.UpdatedWebQSPDataset.process) +updated_web_qsp_dataset.UpdatedWebQSPDataset._build_graph = nvtxit()(updated_web_qsp_dataset.UpdatedWebQSPDataset._build_graph) +updated_web_qsp_dataset.UpdatedWebQSPDataset._retrieve_subgraphs = nvtxit()(updated_web_qsp_dataset.UpdatedWebQSPDataset._retrieve_subgraphs) +updated_web_qsp_dataset.SentenceTransformer.encode = nvtxit()(updated_web_qsp_dataset.SentenceTransformer.encode) +updated_web_qsp_dataset.retrieval_via_pcst = nvtxit()(updated_web_qsp_dataset.retrieval_via_pcst) + +updated_web_qsp_dataset.get_features_for_triplets_groups = nvtxit()(updated_web_qsp_dataset.get_features_for_triplets_groups) +updated_web_qsp_dataset.LargeGraphIndexer.get_node_features = nvtxit()(updated_web_qsp_dataset.LargeGraphIndexer.get_node_features) +updated_web_qsp_dataset.LargeGraphIndexer.get_edge_features = nvtxit()(updated_web_qsp_dataset.LargeGraphIndexer.get_edge_features) +updated_web_qsp_dataset.LargeGraphIndexer.get_edge_features_iter = nvtxit()(updated_web_qsp_dataset.LargeGraphIndexer.get_edge_features_iter) +updated_web_qsp_dataset.LargeGraphIndexer.get_node_features_iter = nvtxit()(updated_web_qsp_dataset.LargeGraphIndexer.get_node_features_iter) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--capture-torch-kernels", "-k", action="store_true") + args = parser.parse_args() + if args.capture_torch_kernels: + with torch.autograd.profiler.emit_nvtx(): + ds = updated_web_qsp_dataset.UpdatedWebQSPDataset('update_ds', force_reload=True) + else: + ds = updated_web_qsp_dataset.UpdatedWebQSPDataset('update_ds', force_reload=True) \ No newline at end of file diff --git a/examples/llm_plus_gnn/test_nvtx_webqsp.py b/examples/llm_plus_gnn/test_nvtx_webqsp.py new file mode 100644 index 0000000000000..1db8e5934728a --- /dev/null +++ b/examples/llm_plus_gnn/test_nvtx_webqsp.py @@ -0,0 +1,19 @@ +from torch_geometric.datasets import web_qsp_dataset +from torch_geometric.profile import nvtxit +import torch +import argparse + +# Apply Patches +web_qsp_dataset.retrieval_via_pcst = nvtxit()(web_qsp_dataset.retrieval_via_pcst) +web_qsp_dataset.WebQSPDataset.process = nvtxit()(web_qsp_dataset.WebQSPDataset.process) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--capture-torch-kernels", "-k", action="store_true") + args = parser.parse_args() + if args.capture_torch_kernels: + with torch.autograd.profiler.emit_nvtx(): + ds = web_qsp_dataset.WebQSPDataset('baseline') + else: + ds = web_qsp_dataset.WebQSPDataset('baseline') \ No newline at end of file