diff --git a/examples/llm_plus_gnn/nvtx_rag_backend_example.py b/examples/llm_plus_gnn/nvtx_examples/nvtx_rag_backend_example.py similarity index 100% rename from examples/llm_plus_gnn/nvtx_rag_backend_example.py rename to examples/llm_plus_gnn/nvtx_examples/nvtx_rag_backend_example.py diff --git a/examples/llm_plus_gnn/nvtx_uwebqsp_example.py b/examples/llm_plus_gnn/nvtx_examples/nvtx_uwebqsp_example.py similarity index 100% rename from examples/llm_plus_gnn/nvtx_uwebqsp_example.py rename to examples/llm_plus_gnn/nvtx_examples/nvtx_uwebqsp_example.py diff --git a/examples/llm_plus_gnn/nvtx_webqsp_example.py b/examples/llm_plus_gnn/nvtx_examples/nvtx_webqsp_example.py similarity index 100% rename from examples/llm_plus_gnn/nvtx_webqsp_example.py rename to examples/llm_plus_gnn/nvtx_examples/nvtx_webqsp_example.py diff --git a/examples/llm_plus_gnn/nvtx_run.sh b/examples/llm_plus_gnn/nvtx_run.sh index b40c32abb0ea..4c6fce7c8224 100755 --- a/examples/llm_plus_gnn/nvtx_run.sh +++ b/examples/llm_plus_gnn/nvtx_run.sh @@ -24,4 +24,4 @@ python_file=$(basename "$1") # Run nsys profile on the Python file nsys profile -c cudaProfilerApi --capture-range-end repeat -t cuda,nvtx,osrt,cudnn,cublas --cuda-memory-usage true --cudabacktrace all --force-overwrite true --output=profile_${python_file%.py} python "$1" -echo "Profile data saved as profile_${python_file%.py}.nsys-rep" \ No newline at end of file +echo "Profile data saved as profile_${python_file%.py}.nsys-rep" diff --git a/torch_geometric/profile/nvtx.py b/torch_geometric/profile/nvtx.py index 122ddfad916f..8dbce375ae5a 100644 --- a/torch_geometric/profile/nvtx.py +++ b/torch_geometric/profile/nvtx.py @@ -24,6 +24,17 @@ def end_cuda_profile(prev_state: bool): def nvtxit(name: Optional[str] = None, n_warmups: int = 0, n_iters: Optional[int] = None): + """Enables NVTX profiling for a function. + + Args: + name (Optional[str], optional): Name to give the reference frame for + the function being wrapped. Defaults to the name of the + function in code. + n_warmups (int, optional): Number of iters to call that function + before starting. Defaults to 0. + n_iters (Optional[int], optional): Number of iters of that function to + record. Defaults to all of them. + """ def nvtx(func): nonlocal name