NVTX Profiler changes #3

Closed
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `profile.nvtxit` with some examples ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597))
- Added `loader.RAGQueryLoader` with Remote Backend Example ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597))
- Added `data.LargeGraphIndexer` ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597))
- Added `nn.models.GRetriever` ([#9167](https://github.com/pyg-team/pytorch_geometric/pull/9167))
95 changes: 88 additions & 7 deletions examples/llm_plus_gnn/doc/1_Retrieval.ipynb
@@ -27,7 +27,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we will explore how to construct a RAG retrieval algorithm from a given subgraph."
"In this notebook, we will explore how to construct a RAG retrieval algorithm from a given subgraph, and conduct some experiments to evaluate its runtime performance."
]
},
{
@@ -39,7 +39,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 1,
"metadata": {},
"outputs": [
{
@@ -49,7 +49,7 @@
"<IPython.core.display.Image object>"
]
},
"execution_count": 37,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
@@ -20285,11 +20285,92 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": []
"source": [
"## Evaluating Runtime Performance"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTorch Geometric provides multiple methods for evaluating runtime performance. In this notebook, we use NVTX to profile the different components of our `RAGQueryLoader`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `nvtxit` decorator profiles the utilization and timing of any method it wraps in a Python script."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see an example of this, check out `nvtx_examples/nvtx_rag_backend_example.py`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This script mirrors this notebook's functionality, but notably includes the following code snippet:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"# Patch FeatureStore and GraphStore\n",
"\n",
"SentenceTransformerFeatureStore.retrieve_seed_nodes = nvtxit()(SentenceTransformerFeatureStore.retrieve_seed_nodes)\n",
"SentenceTransformerFeatureStore.retrieve_seed_edges = nvtxit()(SentenceTransformerFeatureStore.retrieve_seed_edges)\n",
"SentenceTransformerFeatureStore.load_subgraph = nvtxit()(SentenceTransformerFeatureStore.load_subgraph)\n",
"NeighborSamplingRAGGraphStore.sample_subgraph = nvtxit()(NeighborSamplingRAGGraphStore.sample_subgraph)\n",
"rag_loader.RAGQueryLoader.query = nvtxit()(rag_loader.RAGQueryLoader.query)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Importantly, this snippet wraps the methods of the FeatureStore and GraphStore, as well as the `query` method of the `RAGQueryLoader`, so that each will be recognized as a unique frame in NVTX."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This can be executed by the included shell script `nvtx_run.sh`:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```bash\n",
"...\n",
"\n",
"# Get the base name of the Python file\n",
"python_file=$(basename \"$1\")\n",
"\n",
"# Run nsys profile on the Python file\n",
"nsys profile -c cudaProfilerApi --capture-range-end repeat -t cuda,nvtx,osrt,cudnn,cublas --cuda-memory-usage true --cudabacktrace all --force-overwrite true --output=profile_${python_file%.py} python \"$1\"\n",
"\n",
"echo \"Profile data saved as profile_${python_file%.py}.nsys-rep\"\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The resulting `.nsys-rep` file can be visualized in the Nsight Systems GUI, which can show the relative timings of the FeatureStore, GraphStore, and RAGQueryLoader methods."
]
}
],
"metadata": {
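As a supplement to the notebook above: the wrapping pattern it describes can be reproduced in a few lines. The sketch below is illustrative only and is not part of this PR; the function name `dense_step`, the tensor sizes, and the suggested file name are all made up. It decorates a function with `nvtxit` and calls it in a loop, so each call shows up as its own NVTX range when the script is launched through `nvtx_run.sh` (or an equivalent `nsys profile -c cudaProfilerApi ...` command).

```python
# Hypothetical standalone script (not part of this PR) showing the pattern the
# notebook describes. Saved as e.g. nvtx_minimal_example.py, it can be run via
# ./nvtx_run.sh nvtx_minimal_example.py so that each call appears as its own
# NVTX range ("dense_step_0", "dense_step_1", ...) in Nsight Systems.
import torch

from torch_geometric.profile import nvtxit


@nvtxit()  # the range name defaults to the wrapped function's name
def dense_step(x: torch.Tensor) -> torch.Tensor:
    return x @ x.t()


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(1024, 1024, device=device)
    for _ in range(5):
        dense_step(x)
```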
101 changes: 101 additions & 0 deletions examples/llm_plus_gnn/nvtx_examples/nvtx_rag_backend_example.py
@@ -0,0 +1,101 @@
# %%
from profiling_utils import create_remote_backend_from_triplets
from rag_feature_store import SentenceTransformerFeatureStore
from rag_graph_store import NeighborSamplingRAGGraphStore
from torch_geometric.loader import rag_loader
from torch_geometric.datasets import UpdatedWebQSPDataset
from torch_geometric.nn.nlp import SentenceTransformer
from torch_geometric.datasets.updated_web_qsp_dataset import preprocess_triplet
from torch_geometric.data import get_features_for_triplets_groups, Data
from itertools import chain
from torch_geometric.profile.nvtx import nvtxit
import torch
import argparse
from typing import Tuple

# %%
# Patch FeatureStore and GraphStore

SentenceTransformerFeatureStore.retrieve_seed_nodes = nvtxit()(SentenceTransformerFeatureStore.retrieve_seed_nodes)
SentenceTransformerFeatureStore.retrieve_seed_edges = nvtxit()(SentenceTransformerFeatureStore.retrieve_seed_edges)
SentenceTransformerFeatureStore.load_subgraph = nvtxit()(SentenceTransformerFeatureStore.load_subgraph)
NeighborSamplingRAGGraphStore.sample_subgraph = nvtxit()(NeighborSamplingRAGGraphStore.sample_subgraph)
rag_loader.RAGQueryLoader.query = nvtxit()(rag_loader.RAGQueryLoader.query)

# %%
ds = UpdatedWebQSPDataset("small_ds_1", force_reload=True, limit=10)

# %%
triplets = list(chain.from_iterable((d['graph'] for d in ds.raw_dataset)))

# %%
questions = ds.raw_dataset['question']

# %%
ground_truth_graphs = get_features_for_triplets_groups(ds.indexer, (d['graph'] for d in ds.raw_dataset), pre_transform=preprocess_triplet)
num_edges = len(ds.indexer._edges)

# %%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SentenceTransformer().to(device)

# %%
fs, gs = create_remote_backend_from_triplets(triplets=triplets, node_embedding_model=model, node_method_to_call="encode", path="backend", pre_transform=preprocess_triplet, node_method_kwargs={"batch_size": 256}, graph_db=NeighborSamplingRAGGraphStore, feature_db=SentenceTransformerFeatureStore).load()

# %%
from torch_geometric.datasets.updated_web_qsp_dataset import retrieval_via_pcst

@nvtxit()
def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3, topk_e: int = 3, cost_e: float = 0.5) -> Data:
    q_emb = model.encode(query)
    textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index()
    textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index()
    out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, textual_edges, topk, topk_e, cost_e)
    out_graph["desc"] = desc
    return out_graph

# %%
query_loader = rag_loader.RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 10}, seed_edges_kwargs={"k_edges": 10}, sampler_kwargs={"num_neighbors": [40]*10}, local_filter=apply_retrieval_via_pcst)

# %%
# Accuracy Metrics to be added to Profiler
def _eidx_helper(subg: Data, ground_truth: Data):
    subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx
    if isinstance(subg_eidx, torch.Tensor):
        subg_eidx = subg_eidx.tolist()
    if isinstance(gt_eidx, torch.Tensor):
        gt_eidx = gt_eidx.tolist()
    subg_e = set(subg_eidx)
    gt_e = set(gt_eidx)
    return subg_e, gt_e
def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int):
    subg_e, gt_e = _eidx_helper(subg, ground_truth)
    total_e = set(range(num_edges))
    tp = len(subg_e & gt_e)
    tn = len(total_e - (subg_e | gt_e))
    return (tp + tn) / num_edges
def check_retrieval_precision(subg: Data, ground_truth: Data):
    subg_e, gt_e = _eidx_helper(subg, ground_truth)
    return len(subg_e & gt_e) / len(subg_e)
def check_retrieval_recall(subg: Data, ground_truth: Data):
    subg_e, gt_e = _eidx_helper(subg, ground_truth)
    return len(subg_e & gt_e) / len(gt_e)


# %%

@nvtxit()
def _run_eval():
    for subg, gt in zip((query_loader.query(q) for q in questions), ground_truth_graphs):
        print(check_retrieval_accuracy(subg, gt, num_edges), check_retrieval_precision(subg, gt), check_retrieval_recall(subg, gt))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--capture-torch-kernels", "-k", action="store_true")
    args = parser.parse_args()
    if args.capture_torch_kernels:
        with torch.autograd.profiler.emit_nvtx():
            _run_eval()
    else:
        _run_eval()
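As a sanity check on the edge-overlap metrics defined in `nvtx_rag_backend_example.py` above, here is a small hand-computed sketch of what they reduce to. The numbers are invented for illustration and the snippet uses plain Python sets instead of `Data` objects, but the arithmetic mirrors `check_retrieval_accuracy`, `check_retrieval_precision`, and `check_retrieval_recall`.

```python
# Toy illustration (not part of the PR): the retrieval metrics above reduce to
# set overlap between retrieved and ground-truth edge indices.
retrieved = {0, 1, 2, 5}   # edge indices returned by the RAG query
ground_truth = {1, 2, 3}   # edge indices in the reference subgraph
num_edges = 10             # total number of edges in the indexed graph

tp = len(retrieved & ground_truth)                             # 2
tn = len(set(range(num_edges)) - (retrieved | ground_truth))   # 5
print("accuracy :", (tp + tn) / num_edges)    # 0.7
print("precision:", tp / len(retrieved))      # 0.5
print("recall   :", tp / len(ground_truth))   # 0.666...
```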
27 changes: 27 additions & 0 deletions examples/llm_plus_gnn/nvtx_examples/nvtx_uwebqsp_example.py
@@ -0,0 +1,27 @@
from torch_geometric.datasets import updated_web_qsp_dataset
from torch_geometric.profile import nvtxit
import torch
import argparse

# Apply Patches
updated_web_qsp_dataset.UpdatedWebQSPDataset.process = nvtxit()(updated_web_qsp_dataset.UpdatedWebQSPDataset.process)
updated_web_qsp_dataset.UpdatedWebQSPDataset._build_graph = nvtxit()(updated_web_qsp_dataset.UpdatedWebQSPDataset._build_graph)
updated_web_qsp_dataset.UpdatedWebQSPDataset._retrieve_subgraphs = nvtxit()(updated_web_qsp_dataset.UpdatedWebQSPDataset._retrieve_subgraphs)
updated_web_qsp_dataset.SentenceTransformer.encode = nvtxit()(updated_web_qsp_dataset.SentenceTransformer.encode)
updated_web_qsp_dataset.retrieval_via_pcst = nvtxit()(updated_web_qsp_dataset.retrieval_via_pcst)

updated_web_qsp_dataset.get_features_for_triplets_groups = nvtxit()(updated_web_qsp_dataset.get_features_for_triplets_groups)
updated_web_qsp_dataset.LargeGraphIndexer.get_node_features = nvtxit()(updated_web_qsp_dataset.LargeGraphIndexer.get_node_features)
updated_web_qsp_dataset.LargeGraphIndexer.get_edge_features = nvtxit()(updated_web_qsp_dataset.LargeGraphIndexer.get_edge_features)
updated_web_qsp_dataset.LargeGraphIndexer.get_edge_features_iter = nvtxit()(updated_web_qsp_dataset.LargeGraphIndexer.get_edge_features_iter)
updated_web_qsp_dataset.LargeGraphIndexer.get_node_features_iter = nvtxit()(updated_web_qsp_dataset.LargeGraphIndexer.get_node_features_iter)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--capture-torch-kernels", "-k", action="store_true")
    args = parser.parse_args()
    if args.capture_torch_kernels:
        with torch.autograd.profiler.emit_nvtx():
            ds = updated_web_qsp_dataset.UpdatedWebQSPDataset('update_ds', force_reload=True)
    else:
        ds = updated_web_qsp_dataset.UpdatedWebQSPDataset('update_ds', force_reload=True)
19 changes: 19 additions & 0 deletions examples/llm_plus_gnn/nvtx_examples/nvtx_webqsp_example.py
@@ -0,0 +1,19 @@
from torch_geometric.datasets import web_qsp_dataset
from torch_geometric.profile import nvtxit
import torch
import argparse

# Apply Patches
web_qsp_dataset.retrieval_via_pcst = nvtxit()(web_qsp_dataset.retrieval_via_pcst)
web_qsp_dataset.WebQSPDataset.process = nvtxit()(web_qsp_dataset.WebQSPDataset.process)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--capture-torch-kernels", "-k", action="store_true")
    args = parser.parse_args()
    if args.capture_torch_kernels:
        with torch.autograd.profiler.emit_nvtx():
            ds = web_qsp_dataset.WebQSPDataset('baseline')
    else:
        ds = web_qsp_dataset.WebQSPDataset('baseline')
27 changes: 27 additions & 0 deletions examples/llm_plus_gnn/nvtx_run.sh
@@ -0,0 +1,27 @@
#!/bin/bash

# Check if the user provided a Python file
if [ -z "$1" ]; then
    echo "Usage: $0 <python_file>"
    exit 1
fi

# Check if the provided file exists
if [[ ! -f "$1" ]]; then
    echo "Error: File '$1' does not exist."
    exit 1
fi

# Check if the provided file is a Python file
if [[ ! "$1" == *.py ]]; then
    echo "Error: '$1' is not a Python file."
    exit 1
fi

# Get the base name of the Python file
python_file=$(basename "$1")

# Run nsys profile on the Python file
nsys profile -c cudaProfilerApi --capture-range-end repeat -t cuda,nvtx,osrt,cudnn,cublas --cuda-memory-usage true --cudabacktrace all --force-overwrite true --output=profile_${python_file%.py} python "$1"

echo "Profile data saved as profile_${python_file%.py}.nsys-rep"
2 changes: 2 additions & 0 deletions torch_geometric/profile/__init__.py
@@ -20,6 +20,7 @@
get_gpu_memory_from_nvidia_smi,
get_model_size,
)
from .nvtx import nvtxit

__all__ = [
'profileit',
@@ -38,6 +39,7 @@
'get_gpu_memory_from_nvidia_smi',
'get_gpu_memory_from_ipex',
'benchmark',
'nvtxit',
]

classes = __all__
66 changes: 66 additions & 0 deletions torch_geometric/profile/nvtx.py
@@ -0,0 +1,66 @@
from functools import wraps
from typing import Optional

import torch

CUDA_PROFILE_STARTED = False


def begin_cuda_profile():
    global CUDA_PROFILE_STARTED
    prev_state = CUDA_PROFILE_STARTED
    if prev_state is False:
        CUDA_PROFILE_STARTED = True
        torch.cuda.cudart().cudaProfilerStart()
    return prev_state


def end_cuda_profile(prev_state: bool):
    global CUDA_PROFILE_STARTED
    CUDA_PROFILE_STARTED = prev_state
    if prev_state is False:
        torch.cuda.cudart().cudaProfilerStop()


def nvtxit(name: Optional[str] = None, n_warmups: int = 0,
           n_iters: Optional[int] = None):
    """Enables NVTX profiling for a function.

    Args:
        name (Optional[str], optional): Name to give the reference frame for
            the function being wrapped. Defaults to the name of the
            function in code.
        n_warmups (int, optional): Number of calls to the function to skip
            before recording begins. Defaults to 0.
        n_iters (Optional[int], optional): Number of calls to the function to
            record. Defaults to all of them.
    """
    def nvtx(func):

        nonlocal name
        iters_so_far = 0
        if name is None:
            name = func.__name__

        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal iters_so_far
            if not torch.cuda.is_available():
                return func(*args, **kwargs)
            elif iters_so_far < n_warmups:
                iters_so_far += 1
                return func(*args, **kwargs)
            elif n_iters is None or iters_so_far < n_iters + n_warmups:
                prev_state = begin_cuda_profile()
                torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}")
                result = func(*args, **kwargs)
                torch.cuda.nvtx.range_pop()
                end_cuda_profile(prev_state)
                iters_so_far += 1
                return result
            else:
                return func(*args, **kwargs)

        return wrapper

    return nvtx
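To make the `n_warmups` and `n_iters` arguments concrete, here is a short usage sketch. It is hypothetical (the function `gather_step` and the chosen values are not part of the PR): with `n_warmups=2` and `n_iters=3`, the first two calls run unprofiled, the next three are recorded as the NVTX ranges `gather_step_2` through `gather_step_4`, and any further calls run unprofiled again. This pairs with the `nsys profile -c cudaProfilerApi` invocation in `nvtx_run.sh`, since `begin_cuda_profile()` / `end_cuda_profile()` toggle the CUDA profiler around the recorded calls.

```python
# Hypothetical usage of nvtxit with explicit warm-up and iteration limits.
import torch

from torch_geometric.profile import nvtxit


# Skip the first 2 calls, then record the next 3 calls as the ranges
# "gather_step_2", "gather_step_3", "gather_step_4"; later calls run unprofiled.
@nvtxit(name="gather_step", n_warmups=2, n_iters=3)
def gather_step(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    return x[index].sum(dim=0)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(1000, 64, device=device)
    index = torch.randint(0, 1000, (256,), device=device)
    for _ in range(10):
        gather_step(x, index)
```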