Skip to content

Commit

Permalink
Add nn.models.GRetriever (#9480)
Browse files Browse the repository at this point in the history
1. #9462
2. **->** #9480
3. #9481
4. #9167

---

Breaks #9167 down further, this time focusing on the G-Retriever model.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
4 people authored Sep 10, 2024
1 parent df3df3b commit 6d9e850
Show file tree
Hide file tree
Showing 6 changed files with 416 additions and 117 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `GRetriever` model ([#9480](https://github.com/pyg-team/pytorch_geometric/pull/9480))
- Added the `ClusterPooling` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))
- Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))
- Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594))
Expand Down
53 changes: 53 additions & 0 deletions test/nn/models/test_g_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import torch

from torch_geometric.nn import GAT, GRetriever
from torch_geometric.nn.nlp import LLM
from torch_geometric.testing import onlyFullTest, withPackage


@onlyFullTest
@withPackage('transformers', 'sentencepiece', 'accelerate')
def test_g_retriever() -> None:
    """Smoke-test :class:`GRetriever` on a single ten-node ring graph.

    Checks the string representation, that a training step yields a
    non-negative loss, and that inference returns one answer per question.
    """
    num_nodes, channels = 10, 1024

    language_model = LLM(
        model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
        num_params=1,
        dtype=torch.float16,
    )
    graph_encoder = GAT(
        in_channels=channels,
        hidden_channels=channels,
        out_channels=channels,
        num_layers=2,
        heads=4,
        norm='batch_norm',
    )
    model = GRetriever(
        llm=language_model,
        gnn=graph_encoder,
        mlp_out_channels=2048,
    )

    expected_repr = ''.join([
        'GRetriever(\n',
        ' llm=LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1),\n',
        ' gnn=GAT(1024, 1024, num_layers=2),\n',
        ')',
    ])
    assert str(model) == expected_repr

    # Node features first, edge features second (same sampling order as
    # before, so seeded runs draw identical values).
    x = torch.randn(num_nodes, channels)

    # Directed ring: 0 -> 1 -> ... -> 9 -> 0.
    source = torch.arange(num_nodes)
    edge_index = torch.stack([source, source.roll(-1)], dim=0)

    edge_attr = torch.randn(edge_index.size(1), channels)

    # All nodes belong to the one and only example in the batch.
    batch = torch.zeros(num_nodes, dtype=torch.long)

    question = ["Is PyG the best open-source GNN library?"]
    label = ["yes!"]

    # Training pass returns the loss:
    loss = model(question, x, edge_index, batch, label, edge_attr)
    assert loss >= 0

    # Inference pass returns one decoded string per question:
    pred = model.inference(question, x, edge_index, batch, edge_attr)
    assert len(pred) == 1
5 changes: 3 additions & 2 deletions test/nn/nlp/test_llm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from torch import Tensor

from torch_geometric.nn.nlp.llm import LLM
from torch_geometric.nn.nlp import LLM
from torch_geometric.testing import onlyFullTest, withPackage


Expand All @@ -12,10 +12,11 @@ def test_llm() -> None:
answer = ["yes!"]

model = LLM(
model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.1",
model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
num_params=1,
dtype=torch.float16,
)
assert str(model) == 'LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1)'

loss = model(question, answer)
assert isinstance(loss, Tensor)
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/nn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .pmlp import PMLP
from .neural_fingerprint import NeuralFingerprint
from .visnet import ViSNet
from .g_retriever import GRetriever

# Deprecated:
from torch_geometric.explain.algorithm.captum import (to_captum_input,
Expand Down Expand Up @@ -75,4 +76,5 @@
'PMLP',
'NeuralFingerprint',
'ViSNet',
'GRetriever',
]
205 changes: 205 additions & 0 deletions torch_geometric/nn/models/g_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from typing import List, Optional

import torch
from torch import Tensor

from torch_geometric.nn.models import GAT
from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
from torch_geometric.utils import scatter


class GRetriever(torch.nn.Module):
    r"""The G-Retriever model from the `"G-Retriever: Retrieval-Augmented
    Generation for Textual Graph Understanding and Question Answering"
    <https://arxiv.org/abs/2402.07630>`_ paper.

    Args:
        llm (LLM): The LLM to use.
        gnn (torch.nn.Module): The GNN to use. Its :obj:`out_channels`
            attribute determines the input size of the projection MLP.
        use_lora (bool, optional): If set to :obj:`True`, will use LORA from
            :obj:`peft` for training the LLM, see
            `here <https://huggingface.co/docs/peft/en/index>`_ for details.
            (default: :obj:`False`)
        gnn_to_use (Any, optional): Unused. The argument is accepted only so
            that existing positional callers keep working; pass the GNN
            instance via :obj:`gnn` instead. (default: :class:`GAT`)
        mlp_out_channels (int, optional): The size of each graph embedding
            after projection. (default: :obj:`4096`)

    .. warning::
        This module has been tested with the following HuggingFace models

        * :obj:`LLM(model_name="meta-llama/Llama-2-7b-chat-hf")`
        * :obj:`LLM(model_name="google/gemma-7b")`

        and may not work with other models. See other models at `HuggingFace
        Models <https://huggingface.co/models>`_ and let us know if you
        encounter any issues.

    .. note::
        For an example of using :class:`GRetriever`, see
        `examples/llm/g_retriever.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/llm/g_retriever.py>`_.
    """
    def __init__(
        self,
        llm: LLM,
        gnn: torch.nn.Module,
        use_lora: bool = False,
        gnn_to_use=GAT,
        mlp_out_channels: int = 4096,
    ) -> None:
        super().__init__()

        self.llm = llm
        # The GNN lives on the same device as the LLM so that its output can
        # be fed to the LLM directly.
        self.gnn = gnn.to(self.llm.device)

        self.word_embedding = self.llm.word_embedding
        self.llm_generator = self.llm.llm
        if use_lora:
            # Imported lazily so that `peft` is only required when LoRA is
            # actually requested.
            from peft import (
                LoraConfig,
                get_peft_model,
                prepare_model_for_kbit_training,
            )
            self.llm_generator = prepare_model_for_kbit_training(
                self.llm_generator)
            lora_r: int = 8
            lora_alpha: int = 16
            lora_dropout: float = 0.05
            lora_target_modules = ['q_proj', 'v_proj']
            config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=lora_target_modules,
                lora_dropout=lora_dropout,
                bias='none',
                task_type='CAUSAL_LM',
            )
            self.llm_generator = get_peft_model(self.llm_generator, config)

        # Two-layer MLP that maps pooled GNN outputs to `mlp_out_channels`.
        mlp_hidden_channels = self.gnn.out_channels
        self.projector = torch.nn.Sequential(
            torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
            torch.nn.Sigmoid(),
            torch.nn.Linear(mlp_hidden_channels, mlp_out_channels),
        ).to(self.llm.device)

    def encode(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Runs the GNN and mean-pools its node outputs per example.

        All inputs are moved to the device of the LLM first.

        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): The batch vector which assigns each node to
                a specific example.
            edge_attr (torch.Tensor, optional): The edge features (if
                supported by the GNN). (default: :obj:`None`)

        Returns:
            torch.Tensor: One row per example, holding the mean of the GNN
            outputs of the nodes assigned to that example.
        """
        x = x.to(self.llm.device)
        edge_index = edge_index.to(self.llm.device)
        if edge_attr is not None:
            edge_attr = edge_attr.to(self.llm.device)
        batch = batch.to(self.llm.device)

        out = self.gnn(x, edge_index, edge_attr=edge_attr)
        return scatter(out, batch, dim=0, reduce='mean')

    def forward(
        self,
        question: List[str],
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        label: List[str],
        edge_attr: Optional[Tensor] = None,
        additional_text_context: Optional[List[str]] = None,
    ) -> Tensor:
        r"""The forward pass.

        Args:
            question (List[str]): The questions/prompts.
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example.
            label (List[str]): The answers/labels.
            edge_attr (torch.Tensor, optional): The edge features (if supported
                by the GNN). (default: :obj:`None`)
            additional_text_context (List[str], optional): Additional context
                to give to the LLM, such as textified knowledge graphs.
                (default: :obj:`None`)

        Returns:
            torch.Tensor: The loss reported by the LLM for the given labels.
        """
        x = self.encode(x, edge_index, batch, edge_attr)
        x = self.projector(x)
        # One chunk of shape `[1, mlp_out_channels]` per example, so that the
        # i-th question can be paired with the i-th graph embedding.
        # `x.split(x.size(0))` would instead yield a single chunk holding the
        # whole batch, which only coincides with this for a batch size of 1.
        # NOTE(review): assumes `LLM._get_embeds` indexes the embeddings per
        # example -- confirm against its implementation.
        xs = x.split(1, dim=0)

        (
            inputs_embeds,
            attention_mask,
            label_input_ids,
        ) = self.llm._get_embeds(question, additional_text_context, xs, label)

        with self.llm.autocast_context:
            outputs = self.llm_generator(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=label_input_ids,
            )

        return outputs.loss

    @torch.no_grad()
    def inference(
        self,
        question: List[str],
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor] = None,
        additional_text_context: Optional[List[str]] = None,
        max_out_tokens: Optional[int] = MAX_NEW_TOKENS,
    ) -> List[str]:
        r"""The inference pass.

        Args:
            question (List[str]): The questions/prompts.
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example.
            edge_attr (torch.Tensor, optional): The edge features (if supported
                by the GNN). (default: :obj:`None`)
            additional_text_context (List[str], optional): Additional context
                to give to the LLM, such as textified knowledge graphs.
                (default: :obj:`None`)
            max_out_tokens (int, optional): How many tokens for the LLM to
                generate. (default: :obj:`32`)

        Returns:
            List[str]: The decoded generations, with special tokens removed.
        """
        x = self.encode(x, edge_index, batch, edge_attr)
        x = self.projector(x)
        # One chunk of shape `[1, mlp_out_channels]` per example (see the
        # matching comment in `forward`).
        # NOTE(review): assumes `LLM._get_embeds` indexes the embeddings per
        # example -- confirm against its implementation.
        xs = x.split(1, dim=0)

        inputs_embeds, attention_mask, _ = self.llm._get_embeds(
            question, additional_text_context, xs)

        # Generation starts from embeddings rather than token ids, so the
        # BOS token id is passed explicitly.
        bos_token = self.llm.tokenizer(
            BOS,
            add_special_tokens=False,
        ).input_ids[0]

        with self.llm.autocast_context:
            outputs = self.llm_generator.generate(
                inputs_embeds=inputs_embeds,
                max_new_tokens=max_out_tokens,
                attention_mask=attention_mask,
                bos_token_id=bos_token,
                use_cache=True,  # Important to set!
            )

        return self.llm.tokenizer.batch_decode(
            outputs,
            skip_special_tokens=True,
        )

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(\n'
                f' llm={self.llm},\n'
                f' gnn={self.gnn},\n'
                f')')
Loading

0 comments on commit 6d9e850

Please sign in to comment.