G-retriever API updates (NVTX, Remote Backend, Large Graph Indexer, Examples) #9666

Merged on Nov 26, 2024 (1,012 commits)
1012 commits
b912652
finished dataloader
May 30, 2024
ce5370b
finished dataloader 2
May 30, 2024
dc3f9f4
add to data transform
May 30, 2024
4488d70
init commit of dataset
May 31, 2024
eb9b761
debug pcst part 1
May 31, 2024
85da0e7
push what i have for now
May 31, 2024
168a37a
fixes after rebase
Jun 1, 2024
2cf727f
more fixes and mocking
Jun 1, 2024
9ab9b22
more mocking
Jun 1, 2024
6a1ee0f
main mocking working
Jun 3, 2024
03f6f4e
pr fixes
Jun 3, 2024
0b4ee3d
migrate large graph indexer
Jun 3, 2024
81a1b8a
fix save overrider
Jun 3, 2024
dd2ec56
formatting 1
Jun 3, 2024
917fe54
formatting 2
Jun 3, 2024
ae8d057
start unittests
Jun 4, 2024
7fdae85
tests done for largegraphindexer
Jun 4, 2024
5f67a84
tests done for largegraphindexer 2
Jun 4, 2024
b0ab66f
tests done for largegraphindexer 3
Jun 4, 2024
bfd1a2f
tests done for largegraphindexer 4
Jun 4, 2024
bcee67a
migrate updated qsp dataset
Jun 4, 2024
2117605
formatting
Jun 4, 2024
866866c
add dataset
Jun 4, 2024
70f0380
add option for new dataloader
Jun 4, 2024
db94397
fix formatting
Jun 4, 2024
a03028a
instantiate ds in train func
Jun 4, 2024
0ec7aac
Restore mapping attrs
zaristei Jun 4, 2024
c1e5cc8
Test edited to cover restoration of data info
zaristei Jun 4, 2024
5a09881
begin trying profiling
Jun 5, 2024
9d1093c
speedup retrieval to avoid timeouts
Jun 6, 2024
27acc5b
init commit for retrieval
Jun 11, 2024
ff24ce9
profiling nb
Jun 11, 2024
8180fa6
pre commit
Jun 11, 2024
2e51703
debug part 1
Jun 12, 2024
0ac3258
pre commit
Jun 12, 2024
e6b2cd0
debugging done, begin qol cleanup and writing tests
Jun 13, 2024
a782e02
fstore and gstore configurable from loader
Jun 14, 2024
0e58080
move rag loader
Jun 14, 2024
8fd61f4
pre commit
Jun 14, 2024
7f001c8
pre transform working
Jun 17, 2024
53ab64a
fix saving bug for pcst hack
Jun 18, 2024
c1fae6f
fix bug in profiler and make blank file to be filled with unittests
Jun 18, 2024
a051259
overhaul test profiling
Jun 18, 2024
3c60f03
begin nvtx profiler
Jun 26, 2024
5c3876d
nvtx profiler optional name
Jun 28, 2024
a2b606c
nvtx profiler optional name 2
Jun 28, 2024
0a588f6
precommit
Jun 28, 2024
f0d7aa3
nvtx shell script
Jun 28, 2024
8a4d0cf
nvtx scripts added for kg loading and backend retrieval
Jun 28, 2024
01696e4
nvtx_webqsp_script
Jun 28, 2024
42cc16e
cudaprofilerapi hook
Jun 28, 2024
d390cce
make tracking the kernels optional
Jun 29, 2024
3c2a20d
nvtx test scripts
Jul 2, 2024
e91a951
bugfixes
Jul 3, 2024
7ddee4d
code to parallelize checking step in g_retriever
Jul 8, 2024
18eabf4
refactor training and demo loop to allow for multiple model benchmarks
Jul 11, 2024
e120760
benchmark script for current experiment
Jul 11, 2024
71788de
gretriever more embedding size
Jul 16, 2024
bf220d1
override switch
Jul 24, 2024
c347994
override switch 2
Jul 24, 2024
3611aa5
rag generation script
Jul 27, 2024
953262f
wikidata dataset
Aug 8, 2024
97d15e2
documentation part 1
Aug 9, 2024
d58568e
multihop
Aug 15, 2024
77cd9a7
Prune out everything that doesn't have to do with LargeGraphIndexer
zaristei Aug 15, 2024
8c340af
fix imports
zaristei Aug 16, 2024
8947a9d
lint
Aug 16, 2024
8b5a462
mypy purgatory 1
Aug 16, 2024
d4c20c4
mypy purgatory 2
Aug 17, 2024
d0c6dfa
mypy purgatory 3
Aug 17, 2024
b835caa
Changelog 0
Aug 17, 2024
79d70c3
mypy purgatory 4
Aug 17, 2024
a15b141
docstrings
Aug 21, 2024
cc3543d
unittest for updated qsp
Aug 21, 2024
ea3c0be
address pr feedback
Aug 27, 2024
eac09a8
Add back in all Remote backend deltas
zaristei Aug 15, 2024
79784cb
changelog 0
Aug 17, 2024
4b0e658
RAG Backend documentation 0
Aug 21, 2024
4faad32
rag generate running on whole dataset now
Aug 21, 2024
567b43e
docstrings
Aug 21, 2024
0bdcc6d
drop performance ref
Aug 23, 2024
36426bd
Add back in nvtx profiler code
zaristei Aug 15, 2024
0927f48
forgot a file
zaristei Aug 15, 2024
109420b
forgot a few more test scripts
zaristei Aug 15, 2024
d7ea7ef
rename to not get detected by pytest
Aug 17, 2024
e63af63
changelog 0
Aug 17, 2024
05269ca
docstrings
Aug 21, 2024
11293ff
nvtx performance additons
Aug 24, 2024
530405c
Changes to allow for multiprocessing in hallucination detection
zaristei Aug 15, 2024
35b8dd0
restore changes that were for benchmarking and experiments
zaristei Aug 15, 2024
46a7d7a
changelog 0
Aug 17, 2024
85bfeaa
docstrings
Aug 21, 2024
23872a3
passs in arbitrary dataset path
Aug 24, 2024
a22cb82
automate gretriever benchmarking for other subgraphs
Aug 24, 2024
117ec9e
begin incorporating approx knn so that dataset can be fully generated
Aug 26, 2024
e23889c
experimental batch query support for rag loader
Aug 26, 2024
68a2070
rag generate specify num samples
Aug 26, 2024
3d5c2cb
switch to approx knn for rag generate multihop example
Aug 26, 2024
9c972e7
argparse for rag generate
Aug 27, 2024
518a9e3
move all benchmark code into own directory
Aug 27, 2024
59564a5
beginning of refactor into doc
Aug 28, 2024
05b95c1
reorder files
Aug 28, 2024
770a7d1
begin merger of uwebqsp and webqsp
Aug 28, 2024
e660a9b
rename
Aug 28, 2024
759f0f8
merger 2
Aug 28, 2024
e5668bd
verbose
Aug 28, 2024
dc54508
out file
Aug 28, 2024
0359b09
rm separate file for benchmarking multitoken gnns and readd changes t…
Aug 29, 2024
e045ed4
rm from benchmark section
Aug 29, 2024
60bd11f
fix formatting
Aug 29, 2024
e748cfc
check dependencies for webqsp
Aug 29, 2024
664cf7a
multihop scripts
Aug 29, 2024
7472539
adjust multihop
Aug 29, 2024
9f2e618
adjust multihop 2
Aug 29, 2024
000931c
more tests for web qsp
Aug 29, 2024
db8e004
fix unittests
Aug 29, 2024
99c8db5
rag docs part 2
Aug 29, 2024
ca54720
add images
Aug 29, 2024
2a49022
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2024
5a97400
precommit
Aug 29, 2024
fb57e26
ipython3 to python
Aug 29, 2024
30fe562
remove batch query for now
Aug 29, 2024
27843dc
precommit 2
Aug 29, 2024
204ac8f
clean up code in docs 1
Aug 29, 2024
3c05386
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2024
20fc46d
remove tqdm bars
Aug 29, 2024
1a33f5a
nvtx unit tests
Aug 29, 2024
9aa4244
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2024
5915b3d
make linter happy
Aug 29, 2024
137d7a2
renamed some files to prevent misnomers
Aug 29, 2024
fafbb9b
typo
Aug 29, 2024
e14acc0
typo 2
Aug 29, 2024
62ce0b2
typo 3
Aug 29, 2024
a90b5d9
fix retrieval test
Aug 29, 2024
018bdcf
fix multihop
Aug 29, 2024
1ff34c8
remove ipynbs due to them being superseded by the rsts
Aug 29, 2024
a597870
readme in llm_plus_gnn
Aug 29, 2024
e5e5453
fix multihop 2
Aug 29, 2024
a959118
linting
Aug 29, 2024
0ef2bbc
close parenthesis
Aug 29, 2024
6fab6a5
fix multihop 3
Aug 29, 2024
d9ef859
fix typo in example
Aug 29, 2024
352c4ab
multihop to multihop rag
Aug 30, 2024
97ab60a
linting
Aug 30, 2024
ab53471
typo
Aug 30, 2024
99e715d
fix bug where force flag wasnt being acknoledged for LLM results
Aug 30, 2024
3eea912
restore patch from rishis commits
Aug 30, 2024
6560566
add verbosity flag for webqsp
Aug 30, 2024
9852a76
fix force reload for webqsp and set to default
Aug 30, 2024
f41e8d6
update llama2 url
Aug 31, 2024
e1f8f62
Merge branch 'pyg-team:master' into zaristei/g_retriever_experiments
zaristei Sep 1, 2024
60ef416
restore webqsp as default
zacool64 Sep 2, 2024
d06a285
fix skip untuned arg issues
zacool64 Sep 2, 2024
f88be6b
reform file structure and add more documentation
zaristei Sep 4, 2024
e50f776
Merge branch 'master' into zaristei/g_retriever_experiments
zaristei Sep 26, 2024
7f5d8f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2024
cd0de7a
fix bad inport
zaristei Sep 26, 2024
a496480
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2024
884ffe5
bug fix 1
zacool64 Sep 30, 2024
52944f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
d096fa7
rename example dir
zaristei Sep 30, 2024
b89ec88
pre-commit
zacool64 Sep 30, 2024
2b0ccb4
pre-commit
zacool64 Sep 30, 2024
db8085d
mypy
zacool64 Sep 30, 2024
1779ce7
fix import change
zacool64 Sep 30, 2024
7aead2a
mypy 2
zacool64 Sep 30, 2024
d13d2ea
mypy 3
zacool64 Sep 30, 2024
04ca150
mypy 4
zacool64 Sep 30, 2024
8089b95
fix nvtx examples
zacool64 Sep 30, 2024
7bc91d9
Merge branch 'zaristei/modded_master' into zaristei/g_retriever_exper…
zaristei Oct 4, 2024
453d61c
Rename to path on master branch
zaristei Oct 4, 2024
29cd5bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 4, 2024
f607374
Merge remote-tracking branch 'origin/master' into zaristei/g_retrieve…
zaristei Oct 25, 2024
374ba5c
Cherry pick changes from larger commit that pertain only to large gra…
zaristei Oct 25, 2024
a43ce62
restore access pattern from master
zacool64 Oct 25, 2024
56a732f
Re-add content associated with multi-hop RAG
zaristei Oct 25, 2024
bf5ffb2
Re-add only content associated with NVTX profiling
zaristei Oct 25, 2024
10bc8f2
move min demo into g_retriever_utils (temporarily)
zacool64 Oct 28, 2024
37ce77f
Move min demo into g_retriever_utils (temporarily)
zacool64 Oct 28, 2024
9a04aff
restore g_retriever example from master
zacool64 Oct 28, 2024
7656d32
fix test
zacool64 Oct 28, 2024
0ce7086
removing readme changes
puririshi98 Oct 28, 2024
2e174b9
fix merge conflict
zacool64 Oct 28, 2024
17586ac
validation to val
zacool64 Oct 28, 2024
9121a80
property method issues
zacool64 Oct 29, 2024
981c9a4
property method issues 2
zacool64 Oct 29, 2024
9c98ed2
.index() instead of .index[]
zacool64 Oct 29, 2024
1581928
multiple splits fix pt 1
zacool64 Oct 30, 2024
7554175
OOM after preproc but then if rerun it runs fine
puririshi98 Oct 30, 2024
50be9f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2024
fff04da
update changelog
zacool64 Oct 31, 2024
371b985
fix oom pt 1
zacool64 Oct 31, 2024
45cde97
revert changes that lead to oom
zacool64 Oct 31, 2024
2689fe4
change access pattern so that oom less likely
zacool64 Oct 31, 2024
fd14441
oom fix pt 2
zacool64 Oct 31, 2024
2821344
fixing import issue
puririshi98 Nov 7, 2024
b3a593d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
15a659b
Update rag_generate.py
puririshi98 Nov 7, 2024
430026d
changing import structure so i can pull this into my pr: https://gith…
puririshi98 Nov 7, 2024
b8853b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
a9bbfcb
Merge branch 'master' into zaristei/g_retriever_experiments
puririshi98 Nov 7, 2024
3f2d3cf
changes for compatability with my PR
puririshi98 Nov 7, 2024
d87fd3f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
5b9266f
indent fix
puririshi98 Nov 7, 2024
3cddbd4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
e4313e8
putting back
puririshi98 Nov 7, 2024
09ab86b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
6548aad
reverting to first commit, will make a some changes to RAGQueryLoader…
puririshi98 Nov 7, 2024
9ebd380
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
a261f0b
reverting to zacks, i will make the necesarry extensive changes to ma…
puririshi98 Nov 7, 2024
1023c81
minor fix
puririshi98 Nov 8, 2024
0532184
revert webqsp for now
zacool64 Nov 12, 2024
ba773b2
Merge branch 'pyg-team:master' into zaristei/g_retriever_experiments_…
zaristei Nov 12, 2024
6096b98
Delete test/datasets/test_web_qsp_dataset.py
zaristei Nov 12, 2024
76d8f83
fix pkling of largegraphindexer
zacool64 Nov 12, 2024
118aff3
Merge branch 'zaristei/g_retriever_experiments' into zaristei/g_retri…
zacool64 Nov 12, 2024
c66d9fe
Merge branch 'master' into zaristei/g_retriever_experiments
puririshi98 Nov 20, 2024
0b699b4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2024
732e6f9
Merge branch 'master' into zaristei/g_retriever_experiments_no_webqsp
zaristei Nov 23, 2024
0ea177d
backward compatibility
zaristei Nov 23, 2024
b87aedc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2024
dcc0d5f
fix networkx dep for old versions
zaristei Nov 23, 2024
21c2315
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2024
27ddd35
fix networkx dep for old versions 2
zaristei Nov 23, 2024
15c74eb
Merge remote-tracking branch 'origin/zaristei/g_retriever_experiments…
zaristei Nov 23, 2024
d4278df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2024
b9f9cf8
Remove dupe webqsp
zaristei Nov 23, 2024
ec61ec3
make webqsp example use current webqsp implementation
zaristei Nov 23, 2024
a7f7f25
col limits on prints
zaristei Nov 23, 2024
23dd66c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2024
14ae2fb
col limits on prints 2
zaristei Nov 23, 2024
c1380ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2024
14e0376
col limits on prints 3
zaristei Nov 23, 2024
204ddcb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2024
0fa20ad
col limits on prints 4
zaristei Nov 23, 2024
8f63e66
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2024
51129d0
col limits on prints 5
zaristei Nov 23, 2024
79c43b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2024
278cada
col limits on prints 6
zaristei Nov 23, 2024
581e184
pre-commit fix
zaristei Nov 23, 2024
257319a
Merge branch 'master' into zaristei/g_retriever_experiments
puririshi98 Nov 25, 2024
a7cfc61
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 25, 2024
0e69cd0
large graph indexer test skip with wrong version
zaristei Nov 25, 2024
794e487
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 25, 2024
f0ffc14
Remove outdated docs 1
zaristei Nov 25, 2024
147fc84
Remove outdated docs 2
zaristei Nov 25, 2024
2140250
Remove outdated docs 3
zaristei Nov 25, 2024
32ad70a
Remove outdated docs 4
zaristei Nov 25, 2024
13301f0
Remove outdated docs 5
zaristei Nov 25, 2024
b5069e2
lint fix
zaristei Nov 25, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `loader.RagQueryLoader` with Remote Backend Example ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `data.LargeGraphIndexer` ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `GIT-Mol` ([#9730](https://github.com/pyg-team/pytorch_geometric/pull/9730))
- Added comment in `g_retriever.py` pointing to `Neo4j` Graph DB integration demo ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9797))
- Added `MoleculeGPT` example ([#9710](https://github.com/pyg-team/pytorch_geometric/pull/9710))
15 changes: 9 additions & 6 deletions examples/llm/README.md
@@ -1,8 +1,11 @@
# Examples for Co-training LLMs and GNNs

| Example | Description |
| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text |
| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction |
| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results |
| Example | Description |
| -------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. |
| [`multihop_rag/`](./multihop_rag/) | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA |
| [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. |
| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction |
| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results |
| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text |
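
The `nvtx_examples/` row describes wrapping functions in NVTX ranges for CUDA runtime analysis. As a rough illustration of that idea, the sketch below defines a hypothetical decorator, `nvtx_range`, built on `torch.cuda.nvtx.range_push` and `range_pop`. It is not the `nvtxit` decorator added by this PR, whose signature is not shown on this page. The sketch assumes PyTorch may be absent and falls back to a plain call in that case.

```python
import functools

try:
    import torch
    _HAS_CUDA = torch.cuda.is_available()
except ImportError:  # the sketch still runs without PyTorch
    torch = None
    _HAS_CUDA = False


def nvtx_range(name=None):
    """Hypothetical decorator: wrap each call in a named NVTX range."""
    def decorator(func):
        label = name or func.__name__

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if _HAS_CUDA:
                torch.cuda.nvtx.range_push(label)
            try:
                return func(*args, **kwargs)
            finally:
                # Always close the range, even if func raises.
                if _HAS_CUDA:
                    torch.cuda.nvtx.range_pop()
        return wrapper
    return decorator


@nvtx_range("load_triplets")
def load_triplets(n):
    # Toy stand-in for a knowledge-graph loading step.
    return [(i, "rel", i + 1) for i in range(n)]


triplets = load_triplets(3)
print(triplets)
```

Ranges pushed this way appear as named spans on the timeline when the script is run under a profiler such as Nsight Systems; without a GPU the decorator is a no-op.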
4 changes: 4 additions & 0 deletions examples/llm/g_retriever.py
@@ -11,6 +11,7 @@
https://github.com/neo4j-product-examples/neo4j-gnn-llm-example
"""
import argparse
import gc
import math
import os.path as osp
import re
@@ -145,6 +146,9 @@ def adjust_learning_rate(param_group, LR, epoch):
    test_loader = DataLoader(test_dataset, batch_size=eval_batch_size,
                             drop_last=False, pin_memory=True, shuffle=False)

    # To clean up after Data Preproc
    gc.collect()
    torch.cuda.empty_cache()
    gnn = GAT(
        in_channels=1024,
        hidden_channels=hidden_channels,
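
The lines added in this hunk release memory left over from data preprocessing before the model is built. A minimal sketch of the same pattern follows, written so that it also runs without PyTorch or a GPU. The function name `free_preprocessing_memory` is hypothetical and does not appear in the PR.

```python
import gc


def free_preprocessing_memory() -> int:
    """Collect unreachable Python objects, then release cached GPU memory.

    Hypothetical helper; returns the number of objects gc found unreachable.
    """
    collected = gc.collect()
    try:
        import torch
    except ImportError:
        return collected  # CPU-only environment without PyTorch
    if torch.cuda.is_available():
        # Hands unused cached blocks back to the driver; live tensors stay.
        torch.cuda.empty_cache()
    return collected


freed = free_preprocessing_memory()
print(f"gc.collect() found {freed} unreachable objects")
```

The order matters: running the garbage collector first frees tensors that were only kept alive by reference cycles, so the subsequent cache flush can actually return their memory.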
11 changes: 11 additions & 0 deletions examples/llm/g_retriever_utils/README.md
@@ -0,0 +1,11 @@
# Examples for LLM and GNN co-training

| Example | Description |
| ---------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`rag_feature_store.py`](./rag_feature_store.py) | A Proof of Concept Implementation of a RAG enabled FeatureStore that can serve as a starting point for implementing a custom RAG Remote Backend |
| [`rag_graph_store.py`](./rag_graph_store.py) | A Proof of Concept Implementation of a RAG enabled GraphStore that can serve as a starting point for implementing a custom RAG Remote Backend |
| [`rag_backend_utils.py`](./rag_backend_utils.py) | Utility functions used for loading a series of Knowledge Graph Triplets into the Remote Backend defined by a FeatureStore and GraphStore |
| [`rag_generate.py`](./rag_generate.py) | Script for generating a unique set of subgraphs from the WebQSP dataset using a custom defined retrieval algorithm (defaults to the FeatureStore and GraphStore provided) |
| [`benchmark_model_archs_rag.py`](./benchmark_model_archs_rag.py) | Script for running a GNN/LLM benchmark on GRetriever while grid searching relevant architecture parameters and datasets. |

NOTE: Evaluating performance on GRetriever with smaller sample sizes may result in subpar performance. It is not unusual for the fine-tuned model/LLM to perform worse than an untrained LLM on very small sample sizes.
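
The grid search mentioned for `benchmark_model_archs_rag.py` amounts to enumerating every combination of GNN type, layer count, and token count. A rough sketch of that enumeration with `itertools.product` is shown here. The value lists are illustrative assumptions rather than the script's defaults, and plain strings stand in for the model classes.

```python
from itertools import product

# Illustrative search space; these values are assumptions for the sketch.
model_types = ["GAT", "MLP"]
num_layers = [2, 4]
num_tokens = [1, 2]

configs = []
for m_type, n_layer, n_tokens in product(model_types, num_layers, num_tokens):
    configs.append({
        # Naming scheme: type_layers_tokens, as in the benchmark script.
        "name": f"{m_type}_{n_layer}_{n_tokens}",
        "num_gnn_layers": n_layer,
        "mlp_out_tokens": n_tokens,
    })

for cfg in configs:
    print(cfg["name"])
```

Each added value multiplies the number of models to train, which is worth keeping in mind given the note above about small sample sizes.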
105 changes: 105 additions & 0 deletions examples/llm/g_retriever_utils/benchmark_model_archs_rag.py
@@ -0,0 +1,105 @@
"""Used to benchmark the performance of an untuned/fine tuned LLM against
GRetriever with various architectures and layer depths.
"""
# %%
import argparse
import sys

import torch

from torch_geometric.datasets import WebQSPDataset
from torch_geometric.nn.models import GAT, MLP, GRetriever

sys.path.append('..')
from minimal_demo import ( # noqa: E402 # isort:skip
    benchmark_models, get_loss, inference_step,
)

# %%
parser = argparse.ArgumentParser(
    description="""Benchmarker for GRetriever\n""" +
    """NOTE: Evaluating with smaller samples may result in poorer""" +
    """ performance for the trained models compared to """ +
    """untrained models.""")
parser.add_argument("--hidden_channels", type=int, default=1024)
parser.add_argument("--learning_rate", type=float, default=1e-5)
parser.add_argument("--epochs", type=int, default=2)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--eval_batch_size", type=int, default=16)
parser.add_argument("--tiny_llama", action='store_true')

parser.add_argument("--dataset_path", type=str, required=False)
# Default to WebQSP split
parser.add_argument("--num_train", type=int, default=2826)
parser.add_argument("--num_val", type=int, default=246)
parser.add_argument("--num_test", type=int, default=1628)

args = parser.parse_args()

# %%
hidden_channels = args.hidden_channels
lr = args.learning_rate
epochs = args.epochs
batch_size = args.batch_size
eval_batch_size = args.eval_batch_size

# %%
if not args.dataset_path:
    ds = WebQSPDataset('benchmark_archs', verbose=True, force_reload=True)
else:
    # We just assume that the size of the dataset accommodates the
    # train/val/test split, because checking may be expensive.
    dataset = torch.load(args.dataset_path)

    class MockDataset:
        """Utility class to patch the fields in WebQSPDataset used by
        GRetriever.
        """
        def __init__(self) -> None:
            pass

        @property
        def split_idxs(self) -> dict:
            # Imitates the WebQSP split method
            return {
                "train":
                torch.arange(args.num_train),
                "val":
                torch.arange(args.num_val) + args.num_train,
                "test":
                torch.arange(args.num_test) + args.num_train + args.num_val,
            }

        def __getitem__(self, idx: int):
            return dataset[idx]

    ds = MockDataset()

# %%
model_names = []
model_classes = []
model_kwargs = []
model_type = ["GAT", "MLP"]
models = {"GAT": GAT, "MLP": MLP}
# Use to vary the depth of the GNN model
num_layers = [4]
# Use to vary the number of LLM tokens reserved for GNN output
num_tokens = [1]
for m_type in model_type:
    for n_layer in num_layers:
        for n_tokens in num_tokens:
            model_names.append(f"{m_type}_{n_layer}_{n_tokens}")
            model_classes.append(GRetriever)
            kwargs = dict(gnn_hidden_channels=hidden_channels,
                          num_gnn_layers=n_layer, gnn_to_use=models[m_type],
                          mlp_out_tokens=n_tokens)
            if args.tiny_llama:
                kwargs['llm_to_use'] = 'TinyLlama/TinyLlama-1.1B-Chat-v0.1'
                kwargs['mlp_out_dim'] = 2048
                kwargs['num_llm_params'] = 1
            model_kwargs.append(kwargs)

# %%
benchmark_models(model_classes, model_names, model_kwargs, ds, lr, epochs,
                 batch_size, eval_batch_size, get_loss, inference_step,
                 skip_LLMs=False, tiny_llama=args.tiny_llama, force=True)
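
The `split_idxs` property in `MockDataset` lays the three splits out back to back in a single index space. A minimal sketch of that arithmetic in plain Python follows, using `range` in place of `torch.arange`. The helper name `contiguous_splits` is hypothetical and not part of the PR.

```python
def contiguous_splits(num_train: int, num_val: int, num_test: int) -> dict:
    """Return back-to-back index ranges for train/val/test.

    Hypothetical helper: mirrors the offsets used by MockDataset.split_idxs,
    but with built-in ranges instead of torch tensors.
    """
    val_start = num_train
    test_start = num_train + num_val
    return {
        "train": range(0, num_train),
        "val": range(val_start, val_start + num_val),
        "test": range(test_start, test_start + num_test),
    }


# Defaults taken from the script's argparse options.
splits = contiguous_splits(2826, 246, 1628)
for name, idx in splits.items():
    print(name, idx.start, idx.stop - 1, len(idx))
```

With the default sizes, a dataset passed through `--dataset_path` needs at least 2826 + 246 + 1628 = 4700 samples, since the script does not check this itself.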