1. #9462
2. **->** #9480
3. #9481
4. #9167

Breaking #9167 down further, focusing on the G-Retriever model this time.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: rusty1s <[email protected]>
1 parent df3df3b · commit 6d9e850 · 6 changed files with 416 additions and 117 deletions.
@@ -0,0 +1,53 @@
import torch

from torch_geometric.nn import GAT, GRetriever
from torch_geometric.nn.nlp import LLM
from torch_geometric.testing import onlyFullTest, withPackage


@onlyFullTest
@withPackage('transformers', 'sentencepiece', 'accelerate')
def test_g_retriever() -> None:
    llm = LLM(
        model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
        num_params=1,
        dtype=torch.float16,
    )

    gnn = GAT(
        in_channels=1024,
        out_channels=1024,
        hidden_channels=1024,
        num_layers=2,
        heads=4,
        norm='batch_norm',
    )

    model = GRetriever(
        llm=llm,
        gnn=gnn,
        mlp_out_channels=2048,
    )
    assert str(model) == ('GRetriever(\n'
                          '  llm=LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1),\n'
                          '  gnn=GAT(1024, 1024, num_layers=2),\n'
                          ')')

    # A 10-node directed ring graph with random node and edge features:
    x = torch.randn(10, 1024)
    edge_index = torch.tensor([
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
    ])
    edge_attr = torch.randn(edge_index.size(1), 1024)
    batch = torch.zeros(x.size(0), dtype=torch.long)  # All nodes in graph 0.

    question = ["Is PyG the best open-source GNN library?"]
    label = ["yes!"]

    # Test train:
    loss = model(question, x, edge_index, batch, label, edge_attr)
    assert loss >= 0

    # Test inference:
    pred = model.inference(question, x, edge_index, batch, edge_attr)
    assert len(pred) == 1
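
For context on the `batch` vector used above: it assigns every node to the example it belongs to, and `GRetriever.encode` (see the module below) mean-pools node embeddings per example via `torch_geometric.utils.scatter`. A minimal standalone sketch with toy shapes (illustrative only, not part of this commit):

import torch

from torch_geometric.utils import scatter

# Two graphs in one mini-batch: nodes 0-2 belong to graph 0,
# nodes 3-4 to graph 1.
x = torch.randn(5, 8)
batch = torch.tensor([0, 0, 0, 1, 1])

# Mean-pool node features per graph, yielding one embedding per graph:
graph_emb = scatter(x, batch, dim=0, reduce='mean')
assert graph_emb.size() == (2, 8)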
@@ -0,0 +1,205 @@
from typing import List, Optional

import torch
from torch import Tensor

from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
from torch_geometric.utils import scatter


class GRetriever(torch.nn.Module):
    r"""The G-Retriever model from the `"G-Retriever: Retrieval-Augmented
    Generation for Textual Graph Understanding and Question Answering"
    <https://arxiv.org/abs/2402.07630>`_ paper.

    Args:
        llm (LLM): The LLM to use.
        gnn (torch.nn.Module): The GNN to use.
        use_lora (bool, optional): If set to :obj:`True`, will use LoRA from
            :obj:`peft` for training the LLM, see
            `here <https://huggingface.co/docs/peft/en/index>`_ for details.
            (default: :obj:`False`)
        mlp_out_channels (int, optional): The size of each graph embedding
            after projection. (default: :obj:`4096`)

    .. warning::
        This module has been tested with the following HuggingFace models

        * :obj:`model_name="meta-llama/Llama-2-7b-chat-hf"`
        * :obj:`model_name="google/gemma-7b"`

        and may not work with other models. See other models at `HuggingFace
        Models <https://huggingface.co/models>`_ and let us know if you
        encounter any issues.

    .. note::
        For an example of using :class:`GRetriever`, see
        `examples/llm/g_retriever.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/llm/g_retriever.py>`_.
    """
    def __init__(
        self,
        llm: LLM,
        gnn: torch.nn.Module,
        use_lora: bool = False,
        mlp_out_channels: int = 4096,
    ) -> None:
        super().__init__()

        self.llm = llm
        self.gnn = gnn.to(self.llm.device)

        self.word_embedding = self.llm.word_embedding
        self.llm_generator = self.llm.llm
        if use_lora:
            from peft import (
                LoraConfig,
                get_peft_model,
                prepare_model_for_kbit_training,
            )
            self.llm_generator = prepare_model_for_kbit_training(
                self.llm_generator)
            lora_r: int = 8
            lora_alpha: int = 16
            lora_dropout: float = 0.05
            lora_target_modules = ['q_proj', 'v_proj']
            config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=lora_target_modules,
                lora_dropout=lora_dropout,
                bias='none',
                task_type='CAUSAL_LM',
            )
            self.llm_generator = get_peft_model(self.llm_generator, config)

        # Two-layer MLP that projects pooled graph embeddings into the
        # LLM token embedding space:
        mlp_hidden_channels = self.gnn.out_channels
        self.projector = torch.nn.Sequential(
            torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
            torch.nn.Sigmoid(),
            torch.nn.Linear(mlp_hidden_channels, mlp_out_channels),
        ).to(self.llm.device)
    def encode(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor],
    ) -> Tensor:
        # Move all graph inputs onto the LLM device:
        x = x.to(self.llm.device)
        edge_index = edge_index.to(self.llm.device)
        if edge_attr is not None:
            edge_attr = edge_attr.to(self.llm.device)
        batch = batch.to(self.llm.device)

        # Mean-pool node embeddings into a single embedding per graph:
        out = self.gnn(x, edge_index, edge_attr=edge_attr)
        return scatter(out, batch, dim=0, reduce='mean')
    def forward(
        self,
        question: List[str],
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        label: List[str],
        edge_attr: Optional[Tensor] = None,
        additional_text_context: Optional[List[str]] = None,
    ):
        r"""The forward pass.

        Args:
            question (List[str]): The questions/prompts.
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1 \}}^N`, which
                assigns each element to a specific example.
            label (List[str]): The answers/labels.
            edge_attr (torch.Tensor, optional): The edge features (if
                supported by the GNN). (default: :obj:`None`)
            additional_text_context (List[str], optional): Additional context
                to give to the LLM, such as textified knowledge graphs.
                (default: :obj:`None`)
        """
        x = self.encode(x, edge_index, batch, edge_attr)
        x = self.projector(x)
        xs = x.split(1, dim=0)  # One projected embedding per graph.

        (
            inputs_embeds,
            attention_mask,
            label_input_ids,
        ) = self.llm._get_embeds(question, additional_text_context, xs, label)

        with self.llm.autocast_context:
            outputs = self.llm_generator(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=label_input_ids,
            )

        return outputs.loss
    @torch.no_grad()
    def inference(
        self,
        question: List[str],
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_attr: Optional[Tensor] = None,
        additional_text_context: Optional[List[str]] = None,
        max_out_tokens: Optional[int] = MAX_NEW_TOKENS,
    ):
        r"""The inference pass.

        Args:
            question (List[str]): The questions/prompts.
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1 \}}^N`, which
                assigns each element to a specific example.
            edge_attr (torch.Tensor, optional): The edge features (if
                supported by the GNN). (default: :obj:`None`)
            additional_text_context (List[str], optional): Additional context
                to give to the LLM, such as textified knowledge graphs.
                (default: :obj:`None`)
            max_out_tokens (int, optional): How many tokens for the LLM to
                generate. (default: :obj:`32`)
        """
        x = self.encode(x, edge_index, batch, edge_attr)
        x = self.projector(x)
        xs = x.split(1, dim=0)  # One projected embedding per graph.

        inputs_embeds, attention_mask, _ = self.llm._get_embeds(
            question, additional_text_context, xs)

        bos_token = self.llm.tokenizer(
            BOS,
            add_special_tokens=False,
        ).input_ids[0]

        with self.llm.autocast_context:
            outputs = self.llm_generator.generate(
                inputs_embeds=inputs_embeds,
                max_new_tokens=max_out_tokens,
                attention_mask=attention_mask,
                bos_token_id=bos_token,
                use_cache=True,  # Important to set!
            )

        return self.llm.tokenizer.batch_decode(
            outputs,
            skip_special_tokens=True,
        )
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(\n'
                f'  llm={self.llm},\n'
                f'  gnn={self.gnn},\n'
                f')')
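
Putting the pieces together, a minimal end-to-end sketch of training and querying `GRetriever`, mirroring the test above (the model choice, hyperparameters, and learning rate are illustrative assumptions, not recommendations from this commit; requires the `transformers`, `sentencepiece`, and `accelerate` packages):

import torch

from torch_geometric.nn import GAT, GRetriever
from torch_geometric.nn.nlp import LLM

llm = LLM(
    model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
    num_params=1,
    dtype=torch.float16,
)
gnn = GAT(in_channels=1024, hidden_channels=1024, out_channels=1024,
          num_layers=2, heads=4, norm='batch_norm')
model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=2048)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# A toy single-graph batch, as in the test above:
x = torch.randn(10, 1024)
edge_index = torch.tensor([
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
])
edge_attr = torch.randn(edge_index.size(1), 1024)
batch = torch.zeros(x.size(0), dtype=torch.long)

question = ['Is PyG the best open-source GNN library?']
label = ['yes!']

# One optimization step on the language-modeling loss:
optimizer.zero_grad()
loss = model(question, x, edge_index, batch, label, edge_attr)
loss.backward()
optimizer.step()

# Autoregressive answer generation:
pred = model.inference(question, x, edge_index, batch, edge_attr)
print(pred[0])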