
Modularizing LLM (G-retriever) code (#9502)
Will merge this in when it's ready to address
#9480 (comment).
Halfway done modularizing.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
puririshi98 and pre-commit-ci[bot] authored Jul 12, 2024
1 parent ba25c51 commit 57da834
Showing 2 changed files with 117 additions and 202 deletions.
117 changes: 15 additions & 102 deletions torch_geometric/nn/models/g_retriever.py
@@ -4,13 +4,7 @@
import torch.nn as nn

from torch_geometric.nn.models import GAT
from torch_geometric.nn.nlp.llm import (
EOS,
IGNORE_INDEX,
LLM,
MAX_NEW_TOKENS,
MAX_TXT_LEN,
)
from torch_geometric.nn.nlp.llm import LLM, MAX_NEW_TOKENS
from torch_geometric.utils import scatter


@@ -163,65 +157,17 @@ def forward(
additional_text_context (List[str], optional): Additional context
to give to the LLM, such as textified knowledge graphs.
"""
batch_size, questions, context, eos_user_tokens, \
bos_embeds, pad_embeds = self.llm_to_use._encode_inputs(question, additional_text_context) # noqa
# encode labels
labels = self.tokenizer(label, add_special_tokens=False)
# encode training specific special token
eos_tokens = self.tokenizer(EOS, add_special_tokens=False)

# encode graphs
num_nodes_per_graph = ptr[1:] - ptr[:-1]
graph_embeds = self.encode_graphs(node_feat, edge_index, edge_attr,
batch)
graph_embeds = self.projector(graph_embeds)
batch_inputs_embeds = []
batch_attention_mask = []
batch_label_input_ids = []
num_nodes_per_graph = ptr[1:] - ptr[:-1]
for i in range(batch_size):
# Add bos & eos token
label_input_ids = labels.input_ids[
i][:MAX_NEW_TOKENS] + eos_tokens.input_ids
if additional_text_context is not None:
input_ids = context.input_ids[
i][:MAX_TXT_LEN] + questions.input_ids[
i] + eos_user_tokens.input_ids + label_input_ids
else:
input_ids = questions.input_ids[
i] + eos_user_tokens.input_ids + label_input_ids
inputs_embeds = self.word_embedding(
torch.tensor(input_ids).to(self.llm_device))
to_cat = [bos_embeds]
if num_nodes_per_graph[i] != 0:
to_cat.append(graph_embeds[i].unsqueeze(0))
to_cat.append(inputs_embeds)
inputs_embeds = torch.cat([i.to(self.llm_device) for i in to_cat],
dim=0)
batch_inputs_embeds.append(inputs_embeds)
batch_attention_mask.append([1] * inputs_embeds.shape[0])
label_input_ids = [IGNORE_INDEX
] * (inputs_embeds.shape[0] -
len(label_input_ids)) + label_input_ids
batch_label_input_ids.append(label_input_ids)
graph_embeds = [
(embed.unsqueeze(0) if num_nodes_per_graph[i] != 0 else None)
for i, embed in enumerate(self.projector(graph_embeds))
]

# pad inputs_embeds
max_length = max([x.shape[0] for x in batch_inputs_embeds])
for i in range(batch_size):
pad_length = max_length - batch_inputs_embeds[i].shape[0]
batch_inputs_embeds[i] = torch.cat([
pad_embeds.repeat(pad_length, 1).to(self.llm_device),
batch_inputs_embeds[i].to(self.llm_device)
])
batch_attention_mask[i] = [0
] * pad_length + batch_attention_mask[i]
batch_label_input_ids[
i] = [IGNORE_INDEX] * pad_length + batch_label_input_ids[i]
inputs_embeds, attention_mask, label_input_ids = self.llm_to_use._get_embeds( # noqa
question, additional_text_context, graph_embeds, label)

inputs_embeds = torch.stack(batch_inputs_embeds,
dim=0).to(self.llm_device)
attention_mask = torch.tensor(batch_attention_mask).to(self.llm_device)
label_input_ids = torch.tensor(batch_label_input_ids).to(
self.llm_device)
with self.llm_to_use.autocast_context:
outputs = self.llm_generator(
inputs_embeds=inputs_embeds,
@@ -263,48 +209,15 @@ def inference(
max_out_tokens (int, optional): How many tokens for the LLM to
generate. (default: {32})
"""
batch_size, questions, context, eos_user_tokens, \
bos_embeds, pad_embeds = self.llm_to_use._encode_inputs(question, additional_text_context) # noqa
# encode graphs
num_nodes_per_graph = ptr[1:] - ptr[:-1]
graph_embeds = self.encode_graphs(node_feat, edge_index, edge_attr,
batch)
graph_embeds = self.projector(graph_embeds)

batch_inputs_embeds = []
batch_attention_mask = []
num_nodes_per_graph = ptr[1:] - ptr[:-1]
for i in range(batch_size):
# Add bos & eos token
if additional_text_context is not None:
input_ids = context.input_ids[
i][:MAX_TXT_LEN] + questions.input_ids[
i] + eos_user_tokens.input_ids
else:
input_ids = questions.input_ids[i] + eos_user_tokens.input_ids
inputs_embeds = self.word_embedding(
torch.tensor(input_ids).to(self.llm_device))
to_cat = [bos_embeds]
if num_nodes_per_graph[i] != 0:
to_cat.append(graph_embeds[i].unsqueeze(0))
to_cat.append(inputs_embeds)
inputs_embeds = torch.cat([i.to(self.llm_device) for i in to_cat],
dim=0)
batch_inputs_embeds.append(inputs_embeds)
batch_attention_mask.append([1] * inputs_embeds.shape[0])

# pad inputs_embeds
max_length = max([x.shape[0] for x in batch_inputs_embeds])
for i in range(batch_size):
pad_length = max_length - batch_inputs_embeds[i].shape[0]
batch_inputs_embeds[i] = torch.cat(
[pad_embeds.repeat(pad_length, 1), batch_inputs_embeds[i]])
batch_attention_mask[i] = [0
] * pad_length + batch_attention_mask[i]

inputs_embeds = torch.stack(batch_inputs_embeds,
dim=0).to(self.llm_device)
attention_mask = torch.tensor(batch_attention_mask).to(self.llm_device)

graph_embeds = [
(embed.unsqueeze(0) if num_nodes_per_graph[i] != 0 else None)
for i, embed in enumerate(self.projector(graph_embeds))
]
inputs_embeds, attention_mask, _ = self.llm_to_use._get_embeds(
question, additional_text_context, graph_embeds)
with self.llm_to_use.autocast_context:
outputs = self.llm_generator.generate(
inputs_embeds=inputs_embeds,
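The net effect of the `g_retriever.py` changes is that the per-sample loop which interleaved graph and token embeddings is gone: `forward()` and `inference()` now build a plain list with one projected graph embedding per sample (`None` for graphs without nodes) and delegate the rest to `LLM._get_embeds`. Below is a minimal, self-contained sketch of that list-building convention using toy tensors; the simple mean pooling stands in for `encode_graphs(...)` + `projector(...)`, and the shapes are assumptions for illustration only, not the library's exact code.

```python
import torch

# Toy batch of 3 graphs with 2, 0 and 3 nodes (CSR-style `ptr`, as in PyG batches).
ptr = torch.tensor([0, 2, 2, 5])
num_nodes_per_graph = ptr[1:] - ptr[:-1]              # tensor([2, 0, 3])

node_feat = torch.randn(5, 16)                        # 5 nodes, 16-dim features
batch = torch.repeat_interleave(
    torch.arange(len(num_nodes_per_graph)), num_nodes_per_graph)

# Stand-in for `encode_graphs(...)` followed by `projector(...)`:
# mean-pool node features per graph (empty graphs yield a zero row).
graph_embeds = torch.zeros(len(num_nodes_per_graph), 16)
graph_embeds.index_add_(0, batch, node_feat)
graph_embeds = graph_embeds / num_nodes_per_graph.clamp(min=1).unsqueeze(-1)

# The convention passed to `_get_embeds`: one (1, dim) tensor per sample,
# or `None` when the corresponding graph has no nodes.
graph_embeds = [
    embed.unsqueeze(0) if num_nodes_per_graph[i] != 0 else None
    for i, embed in enumerate(graph_embeds)
]
print([None if e is None else tuple(e.shape) for e in graph_embeds])
# [(1, 16), None, (1, 16)]
```

This list is what the refactored `forward()` and `inference()` pass as the `embedding` argument of `_get_embeds`.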
202 changes: 102 additions & 100 deletions torch_geometric/nn/nlp/llm.py
@@ -1,4 +1,3 @@
import warnings
from contextlib import nullcontext
from typing import Any, Dict, List, Optional

@@ -117,6 +116,102 @@ def _encode_inputs(
return (batch_size, questions, context, eos_user_tokens, bos_embeds,
pad_embeds)

def _label_input_ids(self, i, label, eos_tokens):
label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS]
label_input_ids += eos_tokens.input_ids # Add EOS token.
return label_input_ids

def _input_ids(self, i, context, question, eos_user_tokens):
input_ids: List[int] = []
if context is not None:
input_ids += context.input_ids[i][:MAX_TXT_LEN]
input_ids += question.input_ids[i]
input_ids += eos_user_tokens.input_ids
return input_ids

def _inputs_embeds(self, i, input_ids, bos_embeds, embedding=None):
inputs_embeds = self.word_embedding(
torch.tensor(input_ids, device=self.llm_device))

to_cat = [bos_embeds]
if embedding is not None and embedding[i] is not None:  # skip empty-graph entries
to_cat.append(embedding[i])
to_cat.append(inputs_embeds)
inputs_embeds = torch.cat([i.to(self.llm_device) for i in to_cat],
dim=0)
return inputs_embeds

def _append_embeds(self, inputs_embeds, batch_inputs_embeds,
batch_attention_mask, label_input_ids=None,
batch_label_input_ids=None):
batch_inputs_embeds.append(inputs_embeds)
batch_attention_mask.append([1] * inputs_embeds.size(0))
if label_input_ids is not None:
label_input_ids = [IGNORE_INDEX] * (
inputs_embeds.size(0) - len(label_input_ids)) + label_input_ids
batch_label_input_ids.append(label_input_ids)
return batch_inputs_embeds, batch_attention_mask, batch_label_input_ids

def _pad_embeds(self, pad_embeds, batch_inputs_embeds,
batch_attention_mask, batch_label_input_ids=None):
max_length = max([x.size(0) for x in batch_inputs_embeds])
batch_size = len(batch_inputs_embeds)
for i in range(batch_size):
pad = max_length - batch_inputs_embeds[i].size(0)
batch_inputs_embeds[i] = torch.cat([
pad_embeds.repeat(pad, 1),
batch_inputs_embeds[i],
])
batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
if batch_label_input_ids is not None:
batch_label_input_ids[i] = ([IGNORE_INDEX] * pad +
batch_label_input_ids[i])
inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
attention_mask = torch.tensor(batch_attention_mask,
device=self.llm_device)
if batch_label_input_ids is not None:
label_input_ids = torch.tensor(batch_label_input_ids,
device=self.llm_device)
else:
label_input_ids = None
return inputs_embeds, attention_mask, label_input_ids

def _get_embeds(self, question, context=None, embedding=None, answer=None):
(batch_size, question, context, eos_user_tokens, bos_embeds,
pad_embeds) = self._encode_inputs(question, context)
if answer is not None:
label = self.tokenizer(answer, add_special_tokens=False)
eos_tokens = self.tokenizer(EOS, add_special_tokens=False)
batch_label_input_ids = []
else:
batch_label_input_ids = None

batch_inputs_embeds = []
batch_attention_mask = []
for i in range(batch_size):
input_ids = self._input_ids(i, context, question, eos_user_tokens)
if answer is not None:
label_input_ids = self._label_input_ids(i, label, eos_tokens)
input_ids += label_input_ids
else:
label_input_ids = None

inputs_embeds = self._inputs_embeds(i, input_ids, bos_embeds,
embedding)

batch_inputs_embeds, batch_attention_mask, batch_label_input_ids = self._append_embeds( # noqa
inputs_embeds, batch_inputs_embeds, batch_attention_mask,
label_input_ids, batch_label_input_ids)

inputs_embeds, attention_mask, label_input_ids = self._pad_embeds(
pad_embeds, batch_inputs_embeds, batch_attention_mask,
batch_label_input_ids)
return inputs_embeds, attention_mask, label_input_ids

def forward(
self,
question: List[str],
@@ -133,65 +228,11 @@ def forward(
LLM, such as textified knowledge graphs. (default: :obj:`None`)
embedding (list[torch.Tensor], optional): RAG embedding
tensors, *i.e.* the embedded form of :obj:`context`. Either
:obj:`context` or :obj:`rag_embeddings` should be used, not
:obj:`context` or :obj:`embedding` should be used, not
both. (default: :obj:`None`)
"""
if context is not None and embedding is not None:
warnings.warn("Using both 'context' and 'embedding' is a waste of "
"compute and memory")

(batch_size, question, context, eos_user_tokens, bos_embeds,
pad_embeds) = self._encode_inputs(question, context)

label = self.tokenizer(answer, add_special_tokens=False)
eos_tokens = self.tokenizer(EOS, add_special_tokens=False)

batch_inputs_embeds = []
batch_attention_mask = []
batch_label_input_ids = []
for i in range(batch_size):
label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS]
label_input_ids += eos_tokens.input_ids # Add EOS token.

input_ids: List[int] = []
if context is not None:
input_ids += context.input_ids[i][:MAX_TXT_LEN]
input_ids += question.input_ids[i]
input_ids += eos_user_tokens.input_ids
input_ids += label_input_ids

inputs_embeds = self.word_embedding(
torch.tensor(input_ids, device=self.llm_device))

to_cat = [bos_embeds]
if embedding is not None:
to_cat.append(embedding[i])
to_cat.append(inputs_embeds)
inputs_embeds = torch.cat(to_cat, dim=0)

batch_inputs_embeds.append(inputs_embeds)
batch_attention_mask.append([1] * inputs_embeds.size(0))
label_input_ids = [IGNORE_INDEX] * (
inputs_embeds.size(0) - len(label_input_ids)) + label_input_ids
batch_label_input_ids.append(label_input_ids)

# Pad input embeddings:
max_length = max([x.size(0) for x in batch_inputs_embeds])
for i in range(batch_size):
pad = max_length - batch_inputs_embeds[i].size(0)
batch_inputs_embeds[i] = torch.cat([
pad_embeds.repeat(pad, 1),
batch_inputs_embeds[i],
])
batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
batch_label_input_ids[i] = ([IGNORE_INDEX] * pad +
batch_label_input_ids[i])

inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
attention_mask = torch.tensor(batch_attention_mask,
device=self.llm_device)
label_input_ids = torch.tensor(batch_label_input_ids,
device=self.llm_device)
inputs_embeds, attention_mask, label_input_ids = self._get_embeds(
question, context, embedding, answer)

with self.autocast_context:
outputs = self.llm(
@@ -219,52 +260,13 @@ def inference(
LLM, such as textified knowledge graphs. (default: :obj:`None`)
embedding (list[torch.Tensor], optional): RAG embedding
tensors, *i.e.* the embedded form of :obj:`context`. Either
:obj:`context` or :obj:`rag_embeddings` should be used, not
:obj:`context` or :obj:`embedding` should be used, not
both. (default: :obj:`None`)
max_tokens (int, optional): How many tokens for the LLM to
generate. (default: :obj:`32`)
"""
if context is not None and embedding is not None:
warnings.warn("Using both 'context' and 'embedding' is a waste of "
"compute and memory")

(batch_size, question, context, eos_user_tokens, bos_embeds,
pad_embeds) = self._encode_inputs(question, context)

batch_inputs_embeds = []
batch_attention_mask = []
for i in range(batch_size):
input_ids: List[int] = []
if context is not None:
input_ids = context.input_ids[i][:MAX_TXT_LEN]
input_ids += question.input_ids[i]
input_ids += eos_user_tokens.input_ids

inputs_embeds = self.word_embedding(
torch.tensor(input_ids, device=self.llm_device))

to_cat = [bos_embeds]
if embedding is not None:
to_cat.append(embedding[i])
to_cat.append(inputs_embeds)
inputs_embeds = torch.cat(to_cat, dim=0)

batch_inputs_embeds.append(inputs_embeds)
batch_attention_mask.append([1] * inputs_embeds.size(0))

# Pad input embeddings:
max_length = max([x.size(0) for x in batch_inputs_embeds])
for i in range(batch_size):
pad = max_length - batch_inputs_embeds[i].size(0)
batch_inputs_embeds[i] = torch.cat([
pad_embeds.repeat(pad, 1),
batch_inputs_embeds[i],
])
batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]

inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
attention_mask = torch.tensor(batch_attention_mask,
device=self.llm_device)
inputs_embeds, attention_mask, _ = self._get_embeds(
question, context, embedding)

bos_token = self.tokenizer(
BOS,
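For readers following the new helpers, the batching scheme that `_append_embeds` and `_pad_embeds` factor out is: left-pad each variable-length embedding sequence with a pad embedding, mark padded positions with 0 in the attention mask, and pad label ids with `IGNORE_INDEX` so the loss only covers answer tokens. Here is a minimal, self-contained sketch of that pattern with toy tensors; the `-100` value for `IGNORE_INDEX` and all shapes are assumptions for the example, and only the padding/masking pattern mirrors the diff.

```python
import torch

IGNORE_INDEX = -100   # assumed value: the usual ignore index for cross-entropy loss
EMB_DIM = 8

# Two samples whose prompt+answer embeddings have different lengths; only the
# trailing answer positions carry real label ids.
batch_inputs_embeds = [torch.randn(5, EMB_DIM), torch.randn(3, EMB_DIM)]
batch_attention_mask = [[1] * 5, [1] * 3]
batch_label_input_ids = [
    [IGNORE_INDEX] * 3 + [11, 12],    # last 2 positions are supervised
    [IGNORE_INDEX] * 2 + [21],        # last position is supervised
]
pad_embeds = torch.zeros(1, EMB_DIM)  # stand-in for the embedded pad token

# Left-pad every sample to the longest sequence, as `_pad_embeds` does.
max_length = max(x.size(0) for x in batch_inputs_embeds)
for i in range(len(batch_inputs_embeds)):
    pad = max_length - batch_inputs_embeds[i].size(0)
    batch_inputs_embeds[i] = torch.cat(
        [pad_embeds.repeat(pad, 1), batch_inputs_embeds[i]])
    batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
    batch_label_input_ids[i] = [IGNORE_INDEX] * pad + batch_label_input_ids[i]

inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)   # (2, 5, 8)
attention_mask = torch.tensor(batch_attention_mask)       # (2, 5), 0 = padding
label_input_ids = torch.tensor(batch_label_input_ids)     # (2, 5), -100 = ignored
print(inputs_embeds.shape, attention_mask.tolist(), label_input_ids.tolist())
```

Prepending the padding keeps each sample's last real token in the final position, which is where a decoder-only model starts generating; this is why the helpers prepend `pad_embeds` rather than appending it.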
