Commit 89985d8
forgot a few more test scripts
zaristei committed Aug 17, 2024
1 parent dc8a2bf commit 89985d8
Showing 3 changed files with 147 additions and 0 deletions.
101 changes: 101 additions & 0 deletions examples/llm_plus_gnn/test_nvtx_rag_backend.py
# %%
from profiling_utils import create_remote_backend_from_triplets
from rag_feature_store import SentenceTransformerFeatureStore
from rag_graph_store import NeighborSamplingRAGGraphStore
from torch_geometric.loader import rag_loader
from torch_geometric.datasets import UpdatedWebQSPDataset
from torch_geometric.nn.nlp import SentenceTransformer
from torch_geometric.datasets.updated_web_qsp_dataset import preprocess_triplet
from torch_geometric.data import get_features_for_triplets_groups, Data
from itertools import chain
from torch_geometric.profile.nvtx import nvtxit
import torch
import argparse
from typing import Tuple
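# Example profiling run (assumes the Nsight Systems CLI is available):
#   nsys profile -t nvtx,cuda python test_nvtx_rag_backend.py -k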

# %%
# Patch FeatureStore and GraphStore
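# nvtxit() returns a decorator that wraps each call in an NVTX range, so the
# retrieval stages show up as named regions in an Nsight Systems timeline.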

SentenceTransformerFeatureStore.retrieve_seed_nodes = nvtxit()(
    SentenceTransformerFeatureStore.retrieve_seed_nodes)
SentenceTransformerFeatureStore.retrieve_seed_edges = nvtxit()(
    SentenceTransformerFeatureStore.retrieve_seed_edges)
SentenceTransformerFeatureStore.load_subgraph = nvtxit()(
    SentenceTransformerFeatureStore.load_subgraph)
NeighborSamplingRAGGraphStore.sample_subgraph = nvtxit()(
    NeighborSamplingRAGGraphStore.sample_subgraph)
rag_loader.RAGQueryLoader.query = nvtxit()(rag_loader.RAGQueryLoader.query)

# %%
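# Load a small slice of the dataset (10 examples) so the profiled run stays short.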
ds = UpdatedWebQSPDataset("small_ds_1", force_reload=True, limit=10)

# %%
triplets = list(chain.from_iterable(d['graph'] for d in ds.raw_dataset))

# %%
questions = ds.raw_dataset['question']

# %%
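# Ground-truth subgraphs for each question, used by the accuracy metrics below.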
ground_truth_graphs = get_features_for_triplets_groups(
    ds.indexer, (d['graph'] for d in ds.raw_dataset),
    pre_transform=preprocess_triplet)
num_edges = len(ds.indexer._edges)

# %%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SentenceTransformer().to(device)

# %%
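# Build the remote backend: node features are embedded with the sentence
# transformer and served from a feature-store / graph-store pair.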
fs, gs = create_remote_backend_from_triplets(
    triplets=triplets, node_embedding_model=model,
    node_method_to_call="encode", path="backend",
    pre_transform=preprocess_triplet,
    node_method_kwargs={"batch_size": 256},
    graph_db=NeighborSamplingRAGGraphStore,
    feature_db=SentenceTransformerFeatureStore).load()

# %%
from torch_geometric.datasets.updated_web_qsp_dataset import retrieval_via_pcst

@nvtxit()
def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3,
                             topk_e: int = 3, cost_e: float = 0.5) -> Data:
    q_emb = model.encode(query)
    textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index()
    textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index()
    out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes,
                                         textual_edges, topk, topk_e, cost_e)
    # Attach the textual description to the PCST-pruned subgraph and return it.
    out_graph["desc"] = desc
    return out_graph

# %%
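# Retrieval pipeline: k-nearest-neighbor seed nodes/edges from the feature
# store, neighbor sampling from the graph store, then PCST as a local filter.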
query_loader = rag_loader.RAGQueryLoader(
    data=(fs, gs), seed_nodes_kwargs={"k_nodes": 10},
    seed_edges_kwargs={"k_edges": 10},
    sampler_kwargs={"num_neighbors": [40] * 10},
    local_filter=apply_retrieval_via_pcst)

# %%
# Accuracy Metrics to be added to Profiler
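# The metrics below compare retrieved edge ids against the ground-truth
# subgraph's edge ids: accuracy counts correctly included plus correctly
# excluded edges over all `num_edges`, while precision and recall are taken
# over the retrieved and ground-truth edge sets respectively.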
def _eidx_helper(subg: Data, ground_truth: Data):
    subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx
    if isinstance(subg_eidx, torch.Tensor):
        subg_eidx = subg_eidx.tolist()
    if isinstance(gt_eidx, torch.Tensor):
        gt_eidx = gt_eidx.tolist()
    subg_e = set(subg_eidx)
    gt_e = set(gt_eidx)
    return subg_e, gt_e


def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int):
    subg_e, gt_e = _eidx_helper(subg, ground_truth)
    total_e = set(range(num_edges))
    tp = len(subg_e & gt_e)
    tn = len(total_e - (subg_e | gt_e))
    return (tp + tn) / num_edges


def check_retrieval_precision(subg: Data, ground_truth: Data):
    subg_e, gt_e = _eidx_helper(subg, ground_truth)
    return len(subg_e & gt_e) / len(subg_e)


def check_retrieval_recall(subg: Data, ground_truth: Data):
    subg_e, gt_e = _eidx_helper(subg, ground_truth)
    return len(subg_e & gt_e) / len(gt_e)


# %%
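# Query every question and print accuracy, precision, and recall per example.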

@nvtxit()
def _run_eval():
    for subg, gt in zip((query_loader.query(q) for q in questions),
                        ground_truth_graphs):
        print(check_retrieval_accuracy(subg, gt, num_edges),
              check_retrieval_precision(subg, gt),
              check_retrieval_recall(subg, gt))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--capture-torch-kernels", "-k", action="store_true")
    args = parser.parse_args()
    if args.capture_torch_kernels:
        # emit_nvtx() additionally wraps every autograd op in an NVTX range,
        # so CUDA kernels can be correlated with the Python-level stages.
        with torch.autograd.profiler.emit_nvtx():
            _run_eval()
    else:
        _run_eval()
27 changes: 27 additions & 0 deletions examples/llm_plus_gnn/test_nvtx_uwebqsp.py
from torch_geometric.datasets import updated_web_qsp_dataset
from torch_geometric.profile import nvtxit
import torch
import argparse

# Apply Patches
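# Wrap the dataset-construction stages in NVTX ranges to see where
# preprocessing time goes (graph building, subgraph retrieval, encoding,
# and the indexer's feature lookups).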
updated_web_qsp_dataset.UpdatedWebQSPDataset.process = nvtxit()(
    updated_web_qsp_dataset.UpdatedWebQSPDataset.process)
updated_web_qsp_dataset.UpdatedWebQSPDataset._build_graph = nvtxit()(
    updated_web_qsp_dataset.UpdatedWebQSPDataset._build_graph)
updated_web_qsp_dataset.UpdatedWebQSPDataset._retrieve_subgraphs = nvtxit()(
    updated_web_qsp_dataset.UpdatedWebQSPDataset._retrieve_subgraphs)
updated_web_qsp_dataset.SentenceTransformer.encode = nvtxit()(
    updated_web_qsp_dataset.SentenceTransformer.encode)
updated_web_qsp_dataset.retrieval_via_pcst = nvtxit()(
    updated_web_qsp_dataset.retrieval_via_pcst)

updated_web_qsp_dataset.get_features_for_triplets_groups = nvtxit()(
    updated_web_qsp_dataset.get_features_for_triplets_groups)
updated_web_qsp_dataset.LargeGraphIndexer.get_node_features = nvtxit()(
    updated_web_qsp_dataset.LargeGraphIndexer.get_node_features)
updated_web_qsp_dataset.LargeGraphIndexer.get_edge_features = nvtxit()(
    updated_web_qsp_dataset.LargeGraphIndexer.get_edge_features)
updated_web_qsp_dataset.LargeGraphIndexer.get_edge_features_iter = nvtxit()(
    updated_web_qsp_dataset.LargeGraphIndexer.get_edge_features_iter)
updated_web_qsp_dataset.LargeGraphIndexer.get_node_features_iter = nvtxit()(
    updated_web_qsp_dataset.LargeGraphIndexer.get_node_features_iter)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--capture-torch-kernels", "-k", action="store_true")
    args = parser.parse_args()
    if args.capture_torch_kernels:
        with torch.autograd.profiler.emit_nvtx():
            ds = updated_web_qsp_dataset.UpdatedWebQSPDataset(
                'update_ds', force_reload=True)
    else:
        ds = updated_web_qsp_dataset.UpdatedWebQSPDataset(
            'update_ds', force_reload=True)
19 changes: 19 additions & 0 deletions examples/llm_plus_gnn/test_nvtx_webqsp.py
from torch_geometric.datasets import web_qsp_dataset
from torch_geometric.profile import nvtxit
import torch
import argparse

# Apply Patches
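# Wrap the original WebQSPDataset's processing to produce a baseline NVTX
# trace for comparison with the updated dataset above.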
web_qsp_dataset.retrieval_via_pcst = nvtxit()(web_qsp_dataset.retrieval_via_pcst)
web_qsp_dataset.WebQSPDataset.process = nvtxit()(web_qsp_dataset.WebQSPDataset.process)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--capture-torch-kernels", "-k", action="store_true")
    args = parser.parse_args()
    if args.capture_torch_kernels:
        with torch.autograd.profiler.emit_nvtx():
            ds = web_qsp_dataset.WebQSPDataset('baseline')
    else:
        ds = web_qsp_dataset.WebQSPDataset('baseline')
