diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0f8bb52a74d3..e6e757d53f6a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added the `GRetriever` model ([#9480](https://github.com/pyg-team/pytorch_geometric/pull/9480))
 - Added the `ClusterPooling` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))
 - Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))
 - Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594))
diff --git a/test/nn/models/test_g_retriever.py b/test/nn/models/test_g_retriever.py
new file mode 100644
index 000000000000..899e70730cc9
--- /dev/null
+++ b/test/nn/models/test_g_retriever.py
@@ -0,0 +1,53 @@
+import torch
+
+from torch_geometric.nn import GAT, GRetriever
+from torch_geometric.nn.nlp import LLM
+from torch_geometric.testing import onlyFullTest, withPackage
+
+
+@onlyFullTest
+@withPackage('transformers', 'sentencepiece', 'accelerate')
+def test_g_retriever() -> None:
+    llm = LLM(
+        model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
+        num_params=1,
+        dtype=torch.float16,
+    )
+
+    gnn = GAT(
+        in_channels=1024,
+        out_channels=1024,
+        hidden_channels=1024,
+        num_layers=2,
+        heads=4,
+        norm='batch_norm',
+    )
+
+    model = GRetriever(
+        llm=llm,
+        gnn=gnn,
+        mlp_out_channels=2048,
+    )
+    assert str(model) == ('GRetriever(\n'
+                          '  llm=LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1),\n'
+                          '  gnn=GAT(1024, 1024, num_layers=2),\n'
+                          ')')
+
+    x = torch.randn(10, 1024)
+    edge_index = torch.tensor([
+        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
+    ])
+    edge_attr = torch.randn(edge_index.size(1), 1024)
+    batch = torch.zeros(x.size(0), dtype=torch.long)
+
+    question = ["Is PyG the best open-source GNN library?"]
+    label = ["yes!"]
+
+    # Test train:
+    loss = model(question, x, edge_index, batch, label, edge_attr)
+    assert loss >= 0
+
+    # Test inference:
+    pred = model.inference(question, x, edge_index, batch, edge_attr)
+    assert len(pred) == 1
diff --git a/test/nn/nlp/test_llm.py b/test/nn/nlp/test_llm.py
index 03fa29622f5b..c3f5b3b835ca 100644
--- a/test/nn/nlp/test_llm.py
+++ b/test/nn/nlp/test_llm.py
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 
-from torch_geometric.nn.nlp.llm import LLM
+from torch_geometric.nn.nlp import LLM
 from torch_geometric.testing import onlyFullTest, withPackage
 
 
@@ -12,10 +12,11 @@ def test_llm() -> None:
     answer = ["yes!"]
 
     model = LLM(
-        model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.1",
+        model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
         num_params=1,
         dtype=torch.float16,
     )
+    assert str(model) == 'LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1)'
 
     loss = model(question, answer)
     assert isinstance(loss, Tensor)
diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py
index 334970da5c62..7cfadf0143b2 100644
--- a/torch_geometric/nn/models/__init__.py
+++ b/torch_geometric/nn/models/__init__.py
@@ -28,6 +28,7 @@ from .pmlp import PMLP
 from .neural_fingerprint import NeuralFingerprint
 from .visnet import ViSNet
+from .g_retriever import GRetriever
 
 # Deprecated:
 from torch_geometric.explain.algorithm.captum import (to_captum_input,
@@ -75,4 +76,5 @@
     'PMLP',
     'NeuralFingerprint',
     'ViSNet',
+    'GRetriever',
 ]
diff --git a/torch_geometric/nn/models/g_retriever.py b/torch_geometric/nn/models/g_retriever.py
new file mode 100644
index 000000000000..43075196f67b
--- /dev/null
+++ b/torch_geometric/nn/models/g_retriever.py
@@ -0,0 +1,205 @@
+from typing import List, Optional
+
+import torch
+from torch import Tensor
+
+from torch_geometric.nn.models import GAT
+from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
+from torch_geometric.utils import scatter
+
+
+class GRetriever(torch.nn.Module):
+    r"""The G-Retriever model from the `"G-Retriever: Retrieval-Augmented
+    Generation for Textual Graph Understanding and Question Answering"
+    <https://arxiv.org/abs/2402.07630>`_ paper.
+
+    Args:
+        llm (LLM): The LLM to use.
+        gnn (torch.nn.Module): The GNN to use.
+        use_lora (bool, optional): If set to :obj:`True`, will use LoRA from
+            :obj:`peft` for training the LLM, see
+            `here <https://huggingface.co/docs/peft>`_ for details.
+            (default: :obj:`False`)
+        mlp_out_channels (int, optional): The size of each graph embedding
+            after projection. (default: :obj:`4096`)
+
+    .. warning::
+        This module has been tested with the following HuggingFace models
+
+        * :obj:`model_name="meta-llama/Llama-2-7b-chat-hf"`
+        * :obj:`model_name="google/gemma-7b"`
+
+        and may not work with other models. See other models at `HuggingFace
+        Models <https://huggingface.co/models>`_ and let us know if you
+        encounter any issues.
+
+    .. note::
+        For an example of using :class:`GRetriever`, see
+        `examples/llm/g_retriever.py
+        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/llm/g_retriever.py>`_.
+    """
+    def __init__(
+        self,
+        llm: LLM,
+        gnn: torch.nn.Module,
+        use_lora: bool = False,
+        gnn_to_use=GAT,
+        mlp_out_channels: int = 4096,
+    ) -> None:
+        super().__init__()
+
+        self.llm = llm
+        self.gnn = gnn.to(self.llm.device)
+
+        self.word_embedding = self.llm.word_embedding
+        self.llm_generator = self.llm.llm
+        if use_lora:
+            from peft import (
+                LoraConfig,
+                get_peft_model,
+                prepare_model_for_kbit_training,
+            )
+            self.llm_generator = prepare_model_for_kbit_training(
+                self.llm_generator)
+            lora_r: int = 8
+            lora_alpha: int = 16
+            lora_dropout: float = 0.05
+            lora_target_modules = ['q_proj', 'v_proj']
+            config = LoraConfig(
+                r=lora_r,
+                lora_alpha=lora_alpha,
+                target_modules=lora_target_modules,
+                lora_dropout=lora_dropout,
+                bias='none',
+                task_type='CAUSAL_LM',
+            )
+            self.llm_generator = get_peft_model(self.llm_generator, config)
+
+        mlp_hidden_channels = self.gnn.out_channels
+        self.projector = torch.nn.Sequential(
+            torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
+            torch.nn.Sigmoid(),
+            torch.nn.Linear(mlp_hidden_channels, mlp_out_channels),
+        ).to(self.llm.device)
+
+    def encode(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Tensor,
+        edge_attr: Optional[Tensor],
+    ) -> Tensor:
+        x = x.to(self.llm.device)
+        edge_index = edge_index.to(self.llm.device)
+        if edge_attr is not None:
+            edge_attr = edge_attr.to(self.llm.device)
+        batch = batch.to(self.llm.device)
+
+        out = self.gnn(x, edge_index, edge_attr=edge_attr)
+        return scatter(out, batch, dim=0, reduce='mean')
+
+    def forward(
+        self,
+        question: List[str],
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Tensor,
+        label: List[str],
+        edge_attr: Optional[Tensor] = None,
+        additional_text_context: Optional[List[str]] = None,
+    ):
+        r"""The forward pass.
+
+        Args:
+            question (List[str]): The questions/prompts.
+            x (torch.Tensor): The input node features.
+            edge_index (torch.Tensor): The edge indices.
+            batch (torch.Tensor): The batch vector
+                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
+                each element to a specific example.
+            label (List[str]): The answers/labels.
+            edge_attr (torch.Tensor, optional): The edge features (if supported
+                by the GNN). (default: :obj:`None`)
+            additional_text_context (List[str], optional): Additional context
+                to give to the LLM, such as textified knowledge graphs.
+                (default: :obj:`None`)
+        """
+        x = self.encode(x, edge_index, batch, edge_attr)
+        x = self.projector(x)
+        xs = x.split(1, dim=0)
+
+        (
+            inputs_embeds,
+            attention_mask,
+            label_input_ids,
+        ) = self.llm._get_embeds(question, additional_text_context, xs, label)
+
+        with self.llm.autocast_context:
+            outputs = self.llm_generator(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                return_dict=True,
+                labels=label_input_ids,
+            )
+
+        return outputs.loss
+
+    @torch.no_grad()
+    def inference(
+        self,
+        question: List[str],
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Tensor,
+        edge_attr: Optional[Tensor] = None,
+        additional_text_context: Optional[List[str]] = None,
+        max_out_tokens: Optional[int] = MAX_NEW_TOKENS,
+    ):
+        r"""The inference pass.
+
+        Args:
+            question (List[str]): The questions/prompts.
+            x (torch.Tensor): The input node features.
+            edge_index (torch.Tensor): The edge indices.
+            batch (torch.Tensor): The batch vector
+                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
+                each element to a specific example.
+            edge_attr (torch.Tensor, optional): The edge features (if supported
+                by the GNN). (default: :obj:`None`)
+            additional_text_context (List[str], optional): Additional context
+                to give to the LLM, such as textified knowledge graphs.
+                (default: :obj:`None`)
+            max_out_tokens (int, optional): How many tokens for the LLM to
+                generate. (default: :obj:`32`)
+        """
+        x = self.encode(x, edge_index, batch, edge_attr)
+        x = self.projector(x)
+        xs = x.split(1, dim=0)
+
+        inputs_embeds, attention_mask, _ = self.llm._get_embeds(
+            question, additional_text_context, xs)
+
+        bos_token = self.llm.tokenizer(
+            BOS,
+            add_special_tokens=False,
+        ).input_ids[0]
+
+        with self.llm.autocast_context:
+            outputs = self.llm_generator.generate(
+                inputs_embeds=inputs_embeds,
+                max_new_tokens=max_out_tokens,
+                attention_mask=attention_mask,
+                bos_token_id=bos_token,
+                use_cache=True  # Important to set!
+            )
+
+        return self.llm.tokenizer.batch_decode(
+            outputs,
+            skip_special_tokens=True,
+        )
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}(\n'
+                f'  llm={self.llm},\n'
+                f'  gnn={self.gnn},\n'
+                f')')
diff --git a/torch_geometric/nn/nlp/llm.py b/torch_geometric/nn/nlp/llm.py
index 53b9c68236da..1b8cc0423732 100644
--- a/torch_geometric/nn/nlp/llm.py
+++ b/torch_geometric/nn/nlp/llm.py
@@ -1,10 +1,14 @@
-import warnings
 from contextlib import nullcontext
 from typing import Any, Dict, List, Optional
 
 import torch
 from torch import Tensor
 
+try:
+    from transformers.tokenization_utils_base import BatchEncoding
+except ImportError:
+    BatchEncoding = Dict
+
 BOS = '[INST]'
 EOS_USER = '[/INST]'
 EOS = '[/s]'
@@ -61,23 +65,16 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        from transformers import AutoModelForCausalLM, AutoTokenizer
+        self.model_name = model_name
 
-        if model_name == 'llama2-7b':
-            pretty_model_name = 'LLAMA2'
-            model_name = 'meta-llama/Llama-2-7b-chat-hf'
-        elif model_name == 'gemma':
-            pretty_model_name = 'GEMMA'
-            model_name = 'google/gemma-7b'
-        else:
-            pretty_model_name = model_name
+        from transformers import AutoModelForCausalLM, AutoTokenizer
 
         # A rough heuristic on GPU memory requirements, e.g., we found that
         # LLAMA2 (7B parameters) fits on a 85GB GPU.
         required_memory = 85 * num_params / 7
         kwargs = get_llm_kwargs(required_memory, dtype)
 
-        print(f"Setting up '{pretty_model_name}' with configuration: {kwargs}")
+        print(f"Setting up '{model_name}' with configuration: {kwargs}")
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name,
             use_fast=False,
@@ -88,17 +85,17 @@ def __init__(
         self.word_embedding = self.llm.model.get_input_embeddings()
 
         if 'max_memory' not in kwargs:  # Pure CPU:
-            self.llm_device = torch.device('cpu')
+            self.device = torch.device('cpu')
             self.autocast_context = nullcontext()
         else:
-            self.llm_device = self.llm.device
+            self.device = self.llm.device
             self.autocast_context = torch.cuda.amp.autocast(dtype=dtype)
 
     def _encode_inputs(
         self,
         question: List[str],
         context: Optional[List[str]] = None,
-    ) -> None:
+    ) -> tuple:
         batch_size = len(question)
         questions = self.tokenizer(question, add_special_tokens=False)
         if context is not None:
@@ -109,14 +106,144 @@ def _encode_inputs(
             BOS,
             add_special_tokens=False,
             return_tensors='pt',
-        ).input_ids[0].to(self.llm_device)
+        ).input_ids[0].to(self.device)
         bos_embeds = self.word_embedding(bos_token)
         pad_token = torch.tensor(self.tokenizer.pad_token_id,
-                                 device=self.llm_device)
+                                 device=self.device)
         pad_embeds = self.word_embedding(pad_token).unsqueeze(0)
         return (batch_size, questions, context, eos_user_tokens, bos_embeds,
                 pad_embeds)
 
+    def _label_input_ids(
+        self,
+        i: int,
+        label: BatchEncoding,
+        eos_tokens: BatchEncoding,
+    ) -> List[int]:
+        label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS]
+        label_input_ids = label_input_ids + eos_tokens.input_ids
+        return label_input_ids
+
+    def _input_ids(
+        self,
+        i: int,
+        context: BatchEncoding,
+        question: BatchEncoding,
+        eos_user_tokens: BatchEncoding,
+    ) -> List[int]:
+        input_ids: List[int] = []
+        if context is not None:
+            input_ids += context.input_ids[i][:MAX_TXT_LEN]
+        input_ids += question.input_ids[i]
+        input_ids += eos_user_tokens.input_ids
+        return input_ids
+
+    def _inputs_embeds(
+        self,
+        i: int,
+        input_ids: List[int],
+        bos_embeds: Tensor,
+        embedding: Optional[List[Tensor]] = None,
+    ) -> Tensor:
+        inputs_embeds = self.word_embedding(
+            torch.tensor(input_ids, device=self.device))
+
+        to_cat = [bos_embeds]
+        if embedding is not None and embedding[i] is not None:
+            to_cat.append(embedding[i])
+        to_cat.append(inputs_embeds)
+        return torch.cat(to_cat, dim=0).to(self.device)
+
+    def _append_embeds(
+        self,
+        inputs_embeds: Tensor,
+        batch_inputs_embeds: List[Tensor],
+        batch_attention_mask: List[List[int]],
+        label_input_ids: Optional[List[int]] = None,
+        batch_label_input_ids: Optional[List[List[int]]] = None,
+    ) -> tuple:
+        batch_inputs_embeds.append(inputs_embeds)
+        batch_attention_mask.append([1] * inputs_embeds.size(0))
+        if label_input_ids is not None:
+            pad = inputs_embeds.size(0) - len(label_input_ids)
+            label_input_ids = [IGNORE_INDEX] * pad + label_input_ids
+            batch_label_input_ids.append(label_input_ids)
+        return batch_inputs_embeds, batch_attention_mask, batch_label_input_ids
+
+    def _pad_embeds(
+        self,
+        pad_embeds: Tensor,
+        batch_inputs_embeds: List[Tensor],
+        batch_attention_mask: List[List[int]],
+        batch_label_input_ids: Optional[List[List[int]]] = None,
+    ) -> tuple:
+        max_length = max([x.size(0) for x in batch_inputs_embeds])
+        batch_size = len(batch_inputs_embeds)
+        for i in range(batch_size):
+            pad = max_length - batch_inputs_embeds[i].size(0)
+            batch_inputs_embeds[i] = torch.cat([
+                pad_embeds.repeat(pad, 1),
+                batch_inputs_embeds[i],
+            ])
+            batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
+            if batch_label_input_ids is not None:
+                tmp = [IGNORE_INDEX] * pad + batch_label_input_ids[i]
+                batch_label_input_ids[i] = tmp
+
+        inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
+        attention_mask = torch.tensor(batch_attention_mask, device=self.device)
+        label_input_ids = None
+        if batch_label_input_ids is not None:
+            label_input_ids = torch.tensor(batch_label_input_ids,
+                                           device=self.device)
+        return inputs_embeds, attention_mask, label_input_ids
+
+    def _get_embeds(
+        self,
+        question: List[str],
+        context: Optional[List[str]] = None,
+        embedding: Optional[List[Tensor]] = None,
+        answer: Optional[List[str]] = None,
+    ) -> tuple:
+        (batch_size, question, context, eos_user_tokens, bos_embeds,
+         pad_embeds) = self._encode_inputs(question, context)
+
+        batch_label_input_ids = None
+        if answer is not None:
+            label = self.tokenizer(answer, add_special_tokens=False)
+            eos_tokens = self.tokenizer(EOS, add_special_tokens=False)
+            batch_label_input_ids = []
+
+        batch_inputs_embeds = []
+        batch_attention_mask = []
+        for i in range(batch_size):
+            input_ids = self._input_ids(i, context, question, eos_user_tokens)
+            if answer is not None:
+                label_input_ids = self._label_input_ids(i, label, eos_tokens)
+                input_ids += label_input_ids
+            else:
+                label_input_ids = None
+
+            inputs_embeds = self._inputs_embeds(i, input_ids, bos_embeds,
+                                                embedding)
+
+            (
+                batch_inputs_embeds,
+                batch_attention_mask,
+                batch_label_input_ids,
+            ) = self._append_embeds(
+                inputs_embeds,
+                batch_inputs_embeds,
+                batch_attention_mask,
+                label_input_ids,
+                batch_label_input_ids,
+            )
+
+        inputs_embeds, attention_mask, label_input_ids = self._pad_embeds(
+            pad_embeds, batch_inputs_embeds, batch_attention_mask,
+            batch_label_input_ids)
+
+        return inputs_embeds, attention_mask, label_input_ids
+
     def forward(
         self,
         question: List[str],
@@ -133,65 +260,11 @@ def forward(
                 LLM, such as textified knowledge graphs. (default: :obj:`None`)
             embedding (list[torch.Tensor], optional): RAG embedding tensors,
                 *i.e.* the embedded form of :obj:`context`. Either
-                :obj:`context` or :obj:`rag_embeddings` should be used, not
+                :obj:`context` or :obj:`embedding` should be used, not
                 both. (default: :obj:`None`)
         """
-        if context is not None and embedding is not None:
-            warnings.warn("Using both 'context' and 'embedding' is a waste of "
-                          "compute and memory")
-
-        (batch_size, question, context, eos_user_tokens, bos_embeds,
-         pad_embeds) = self._encode_inputs(question, context)
-
-        label = self.tokenizer(answer, add_special_tokens=False)
-        eos_tokens = self.tokenizer(EOS, add_special_tokens=False)
-
-        batch_inputs_embeds = []
-        batch_attention_mask = []
-        batch_label_input_ids = []
-        for i in range(batch_size):
-            label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS]
-            label_input_ids += eos_tokens.input_ids  # Add EOS token.
-
-            input_ids: List[int] = []
-            if context is not None:
-                input_ids += context.input_ids[i][:MAX_TXT_LEN]
-            input_ids += question.input_ids[i]
-            input_ids += eos_user_tokens.input_ids
-            input_ids += label_input_ids
-
-            inputs_embeds = self.word_embedding(
-                torch.tensor(input_ids, device=self.llm_device))
-
-            to_cat = [bos_embeds]
-            if embedding is not None:
-                to_cat.append(embedding[i])
-            to_cat.append(inputs_embeds)
-            inputs_embeds = torch.cat(to_cat, dim=0)
-
-            batch_inputs_embeds.append(inputs_embeds)
-            batch_attention_mask.append([1] * inputs_embeds.size(0))
-            label_input_ids = [IGNORE_INDEX] * (
-                inputs_embeds.size(0) - len(label_input_ids)) + label_input_ids
-            batch_label_input_ids.append(label_input_ids)
-
-        # Pad input embeddings:
-        max_length = max([x.size(0) for x in batch_inputs_embeds])
-        for i in range(batch_size):
-            pad = max_length - batch_inputs_embeds[i].size(0)
-            batch_inputs_embeds[i] = torch.cat([
-                pad_embeds.repeat(pad, 1),
-                batch_inputs_embeds[i],
-            ])
-            batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
-            batch_label_input_ids[i] = ([IGNORE_INDEX] * pad +
-                                        batch_label_input_ids[i])
-
-        inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
-        attention_mask = torch.tensor(batch_attention_mask,
-                                      device=self.llm_device)
-        label_input_ids = torch.tensor(batch_label_input_ids,
-                                       device=self.llm_device)
+        inputs_embeds, attention_mask, label_input_ids = self._get_embeds(
+            question, context, embedding, answer)
 
         with self.autocast_context:
             outputs = self.llm(
@@ -219,52 +292,13 @@ def inference(
                 LLM, such as textified knowledge graphs. (default: :obj:`None`)
             embedding (list[torch.Tensor], optional): RAG embedding tensors,
                 *i.e.* the embedded form of :obj:`context`. Either
-                :obj:`context` or :obj:`rag_embeddings` should be used, not
+                :obj:`context` or :obj:`embedding` should be used, not
                 both. (default: :obj:`None`)
             max_tokens (int, optional): How many tokens for the LLM to
                 generate. (default: :obj:`32`)
         """
-        if context is not None and embedding is not None:
-            warnings.warn("Using both 'context' and 'embedding' is a waste of "
-                          "compute and memory")
-
-        (batch_size, question, context, eos_user_tokens, bos_embeds,
-         pad_embeds) = self._encode_inputs(question, context)
-
-        batch_inputs_embeds = []
-        batch_attention_mask = []
-        for i in range(batch_size):
-            input_ids: List[int] = []
-            if context is not None:
-                input_ids = context.input_ids[i][:MAX_TXT_LEN]
-            input_ids += question.input_ids[i]
-            input_ids += eos_user_tokens.input_ids
-
-            inputs_embeds = self.word_embedding(
-                torch.tensor(input_ids, device=self.llm_device))
-
-            to_cat = [bos_embeds]
-            if embedding is not None:
-                to_cat.append(embedding[i])
-            to_cat.append(inputs_embeds)
-            inputs_embeds = torch.cat(to_cat, dim=0)
-
-            batch_inputs_embeds.append(inputs_embeds)
-            batch_attention_mask.append([1] * inputs_embeds.size(0))
-
-        # Pad input embeddings:
-        max_length = max([x.size(0) for x in batch_inputs_embeds])
-        for i in range(batch_size):
-            pad = max_length - batch_inputs_embeds[i].size(0)
-            batch_inputs_embeds[i] = torch.cat([
-                pad_embeds.repeat(pad, 1),
-                batch_inputs_embeds[i],
-            ])
-            batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
-
-        inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
-        attention_mask = torch.tensor(batch_attention_mask,
-                                      device=self.llm_device)
+        inputs_embeds, attention_mask, _ = self._get_embeds(
+            question, context, embedding)
 
         bos_token = self.tokenizer(
             BOS,
@@ -281,3 +315,6 @@ def inference(
         )
 
         return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}({self.model_name})'
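
For reference, a minimal usage sketch of the `GRetriever` API added above, mirroring test/nn/models/test_g_retriever.py. The TinyLlama checkpoint, the channel sizes, and the toy cycle graph are illustrative values only; the one constraint is that `mlp_out_channels` matches the word-embedding size of the chosen LLM (2048 for TinyLlama-1.1B), since the projected graph embedding is prepended to the token embeddings of the prompt.

import torch

from torch_geometric.nn import GAT, GRetriever
from torch_geometric.nn.nlp import LLM

llm = LLM(
    model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
    num_params=1,
    dtype=torch.float16,
)
gnn = GAT(
    in_channels=1024,
    hidden_channels=1024,
    out_channels=1024,
    num_layers=2,
    heads=4,
    norm='batch_norm',
)
model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=2048)

# A single toy example: a 10-node cycle graph with random node/edge features:
x = torch.randn(10, 1024)
edge_index = torch.tensor([
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
])
edge_attr = torch.randn(edge_index.size(1), 1024)
batch = torch.zeros(x.size(0), dtype=torch.long)

question = ["Is PyG the best open-source GNN library?"]
label = ["yes!"]

# Training: returns the language-modeling loss over the answer tokens.
loss = model(question, x, edge_index, batch, label, edge_attr)

# Inference: returns one decoded string per example.
pred = model.inference(question, x, edge_index, batch, edge_attr)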