Add gpt2 and hfroberta #45

Open. morioka wants to merge 48 commits into base: main.

Commits (48):
bdeb09b  run and sometimes fail by myself  (morioka, Jan 9, 2021)
3beeedc  Create README_morioka.md  (morioka, Jan 9, 2021)
5e33214  fix loading RoBERTa  (morioka, Jan 12, 2021)
7669dd2  add TODO  (morioka, Jan 12, 2021)
cdaefb8  add memo regarding RoBERTa  (morioka, Jan 12, 2021)
8bfa379  fix options typo  (morioka, Jan 17, 2021)
1f72512  add HF RoBERTa connector (work in progress)  (morioka, Jan 26, 2021)
65f15f0  add HF RoBERTa connector (work in progress)  (morioka, Jan 26, 2021)
ab48672  fix  (morioka, Jan 26, 2021)
29d4d87  avoid div0  (morioka, Jan 26, 2021)
a05717e  fix  (morioka, Jan 26, 2021)
79ffbfd  fix  (morioka, Jan 27, 2021)
aecca5d  turn-on bert-base case  (morioka, Jan 27, 2021)
5df078b  fix  (morioka, Jan 27, 2021)
53a4806  clean  (morioka, Jan 27, 2021)
3ca10c1  fix byte-level BPE issue (an encoded SPC char)  (morioka, Jan 27, 2021)
63a2eb7  for experiment  (morioka, Jan 27, 2021)
06b4066  for experiment  (morioka, Jan 27, 2021)
06ba24b  clean  (morioka, Jan 27, 2021)
b1d7fc8  add memo  (morioka, Jan 27, 2021)
1ea76aa  fix hfroberta_connector  (morioka, Jan 28, 2021)
59dbc47  fix last_results.csv  (morioka, Jan 28, 2021)
01aa1d8  clean  (morioka, Jan 29, 2021)
43c764e  add hfroberta args to options  (morioka, Jan 29, 2021)
392c2b2  fix model dir  (morioka, Feb 1, 2021)
1cd6b86  fix gpt config  (morioka, Feb 1, 2021)
2aed302  hide gpt config  (morioka, Feb 1, 2021)
6163e1a  fix hfroberta get_id  (morioka, Feb 1, 2021)
1d4c38a  clean hfroberta get_id  (morioka, Feb 1, 2021)
71c675e  clean hfroberta_connector  (morioka, Feb 1, 2021)
6cd41c1  fix  (morioka, Feb 2, 2021)
726cf80  fix BOS, EOS tokens for hfRoBERTa  (morioka, Feb 14, 2021)
48a5a2a  clean __get_input_tensors of hfRoBERTA connector  (morioka, Feb 14, 2021)
6eb12ad  add GPT-2 support  (morioka, Feb 15, 2021)
afcf148  fix hfroberta  (morioka, Feb 15, 2021)
156b888  update README_morioka.md  (morioka, Feb 15, 2021)
772dc08  update README.md  (morioka, Feb 15, 2021)
ceec5f6  quick fix not masked_sentences but masked_sentence issue in TREx dataset  (morioka, Feb 15, 2021)
527754e  clean  (morioka, Feb 16, 2021)
719f217  fix  (morioka, Feb 18, 2021)
9b98336  add switch for Negated-LAMA  (morioka, Feb 18, 2021)
831652a  fix sub_label issue in lowercase  (morioka, Feb 24, 2021)
2f6351a  fix handling elmo  (morioka, Aug 17, 2021)
0f6204c  clean  (morioka, Aug 17, 2021)
99a2590  clean  (morioka, Aug 17, 2021)
99ae795  remove fix #30 and #31, tentatively  (morioka, Aug 18, 2021)
7a3efee  remove fix #30 and #31, tentatively  (morioka, Aug 18, 2021)
cf70e4a  reactivate fix lowercase_samples() issue  (morioka, Aug 19, 2021)
15 changes: 14 additions & 1 deletion README.md
@@ -10,6 +10,7 @@ LAMA exposes a transparent and unique interface to use:
- BERT (Devlin et al., 2018)
- ELMo (Peters et al., 2018)
- GPT (Radford et al., 2018)
- GPT-2 (Radford et al., 2019)
- RoBERTa (Liu et al., 2019)

Actually, LAMA is also a beautiful animal.
@@ -185,13 +186,19 @@ BERT pretrained models can be loaded both: (i) passing the name of the model and
* __--bert-vocab-name/--bvn__ : name of vocabulary used to pre-train the BERT model (default = 'vocab.txt')


### RoBERTa
### RoBERTa (Fairseq)

* __--roberta-model-dir/--rmd__ : directory that contains the RoBERTa pre-trained model and the vocabulary (__REQUIRED__)
* __--roberta-model-name/--rmn__ : name of the RoBERTa pre-trained model (default = 'model.pt')
* __--roberta-vocab-name/--rvn__ : name of vocabulary used to pre-train the RoBERTa model (default = 'dict.txt')


### RoBERTa (HuggingFace)

* __--hfroberta-model-dir/--hmd__ : directory that contains the HuggingFace RoBERTa pre-trained model and the vocabulary (__REQUIRED__)
* __--hfroberta-model-name/--hmn__ : name of the HuggingFace RoBERTa pre-trained model (default = 'roberta-base')


### ELMo

* __--elmo-model-dir/--emd__ : directory that contains the ELMo pre-trained model and the vocabulary (__REQUIRED__)
@@ -211,6 +218,12 @@ BERT pretrained models can be loaded both: (i) passing the name of the model and
* __--gpt-model-name/--gmn__ : name of the gpt pre-trained model (default = 'openai-gpt')


### GPT-2

* __--gpt2-model-dir/--g2d__ : directory that contains the gpt2 pre-trained model and the vocabulary (__REQUIRED__)
* __--gpt2-model-name/--g2n__ : name of the gpt2 pre-trained model (default = 'gpt2')


## Evaluate Language Model(s) Generation

options:
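With these options in place, GPT-2 and HuggingFace RoBERTa are selected the same way as the existing models. A minimal sketch of driving an evaluation from Python, not part of the PR: it assumes the upstream `lama/eval_generation.py` entry point with its `--lm`/`--t` flags; only `--gpt2-model-name` comes from the options documented above.

```python
import subprocess

# Hedged example: relies on lama/eval_generation.py accepting --lm and --t
# as in the upstream README; --gpt2-model-name is the option added by this PR.
subprocess.run(
    [
        "python", "lama/eval_generation.py",
        "--lm", "gpt2",
        "--gpt2-model-name", "gpt2",
        "--t", "The cat is on the [MASK].",
    ],
    check=True,
)
```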
37 changes: 36 additions & 1 deletion download_models.sh
@@ -29,6 +29,18 @@ if [[ ! -f gpt/openai-gpt/config.json ]]; then
cd ../..
fi

echo "GPT2"
if [[ ! -f gpt/gpt2/config.json ]]; then
rm -rf 'gpt/gpt2'
mkdir -p 'gpt/gpt2'
cd 'gpt/gpt2'
wget 'https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json' -O vocab.json
wget 'https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt' -O merges.txt
wget -c 'https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin' -O 'pytorch_model.bin'
wget -c 'https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json' -O 'config.json'
cd ../..
fi

echo "BERT BASE LOWERCASED"
if [[ ! -f bert/uncased_L-12_H-768_A-12/bert_config.json ]]; then
mkdir -p 'bert'
@@ -131,13 +143,36 @@ if [[ ! -f bert/cased_L-24_H-1024_A-16/bert_config.json ]]; then
cd ../../
fi

echo "RoBERTa"
if [[ ! -f roberta/roberta.base/dict.txt ]]; then
rm -rf 'roberta/roberta.base'
mkdir -p 'roberta/roberta.base'
cd 'roberta'
wget -c 'https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz'
tar -xzf roberta.base.tar.gz
rm roberta.base.tar.gz
cd ..
fi

echo "HuggingFace RoBERTa"
if [[ ! -f roberta/roberta-base/config.json ]]; then
rm -rf 'roberta/roberta-base'
mkdir -p 'roberta/roberta-base'
cd 'roberta/roberta-base'
wget 'https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json' -O vocab.json
wget 'https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt' -O merges.txt
wget -c 'https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-pytorch_model.bin' -O 'pytorch_model.bin'
wget -c 'https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json' -O 'config.json'
cd ../..
fi


cd "$ROOD_DIR"
echo 'Building common vocab'
if [ ! -f "$DST_DIR/common_vocab_cased.txt" ]; then
python lama/vocab_intersection.py
else
echo 'Already exists. Run to re-build:'
echo 'python util_KB_completion.py'
echo 'python lama/vocab_intersection.py'
fi
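The two new download blocks mirror what `from_pretrained` would fetch by model name, so the result can be sanity-checked by pointing the same loader classes the connectors import at the local directory. A minimal check, not part of the PR; the path below is an assumption about the script's destination directory and may need adjusting.

```python
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

# Placeholder path: substitute the script's actual DST_DIR if it differs.
model_dir = "pre-trained_language_models/gpt/gpt2/"

tokenizer = GPT2Tokenizer.from_pretrained(model_dir)  # vocab.json + merges.txt
model = GPT2LMHeadModel.from_pretrained(model_dir)    # pytorch_model.bin + config.json
model.eval()

print(len(tokenizer.encoder), model.config.vocab_size)  # both 50257 for GPT-2
```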

6 changes: 5 additions & 1 deletion lama/modules/__init__.py
@@ -9,6 +9,8 @@
from .gpt_connector import GPT
from .transformerxl_connector import TransformerXL
from .roberta_connector import Roberta
from .hfroberta_connector import HfRoberta
from .gpt2_connector import GPT2


def build_model_by_name(lm, args, verbose=True):
@@ -22,7 +24,9 @@ def build_model_by_name(lm, args, verbose=True):
        bert=Bert,
        gpt=GPT,
        transformerxl=TransformerXL,
        roberta=Roberta
        roberta=Roberta,
        hfroberta=HfRoberta,
        gpt2=GPT2
    )
    if lm not in MODEL_NAME_TO_CLASS:
        raise ValueError("Unrecognized Language Model: %s." % lm)
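Registering `hfroberta` and `gpt2` here is what makes the new `--lm` values resolve to the connectors below. A minimal sketch of the lookup, assuming an argparse-style namespace that carries the fields the GPT-2 connector reads (`gpt2_model_dir`, `gpt2_model_name`, matching the README options above).

```python
from argparse import Namespace

from lama.modules import build_model_by_name

# gpt2_model_dir=None falls back to loading "gpt2" from the HuggingFace cache,
# mirroring the connector's constructor logic.
args = Namespace(gpt2_model_dir=None, gpt2_model_name="gpt2")
model = build_model_by_name("gpt2", args)

# Unknown names still fail loudly.
try:
    build_model_by_name("gpt3", args)
except ValueError as err:
    print(err)  # Unrecognized Language Model: gpt3.
```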
25 changes: 20 additions & 5 deletions lama/modules/base_connector.py
@@ -12,15 +12,24 @@
BERT_CLS = "[CLS]"
BERT_SEP = "[SEP]"
BERT_PAD = "[PAD]"

ELMO_UNK = "<UNK>"
ELMO_START_SENTENCE = "<S>"
ELMO_END_SENTENCE = "</S>"

OPENAI_UNK = "<unk>"
OPENAI_EOS = "<eos>"
ROBERTA_MASK = "<mask>"
ROBERTA_START_SENTENCE = "<s>"
ROBERTA_END_SENTENCE = "</s>"
ROBERTA_VOCAB_SIZE = 50266

ROBERTA_MASK = "<mask>" # MASK for fairseq/huggingface RoBERTa
ROBERTA_VOCAB_SIZE = 50266 # for fairseq RoBERTa

ROBERTA_START_SENTENCE = "<s>" # BOS, CLS for huggingface RoBERTa
ROBERTA_END_SENTENCE = "</s>" # EOS, SEP for huggingface RoBERTa
ROBERTA_UNK = "<unk>" # UNK for huggingface RoBERTa
ROBERTA_PAD = "<pad>" # PAD for huggingface RoBERTa

GPT2_EOS = "<|endoftext|>" # BOS, EOS, UNK, PAD for GPT2


SPECIAL_SYMBOLS = [
    MASK,
@@ -32,7 +41,13 @@
    ELMO_START_SENTENCE,
    ELMO_END_SENTENCE,
    OPENAI_UNK,
    OPENAI_EOS
    OPENAI_EOS,
    ROBERTA_MASK,
    # ROBERTA_UNK,
    ROBERTA_PAD,
    ROBERTA_START_SENTENCE,
    ROBERTA_END_SENTENCE,
    GPT2_EOS
]

SPACE_NORMALIZER = re.compile(r"\s+")
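The comments make the split explicit: both RoBERTa variants share `<mask>`, `<s>`, `</s>`, `<unk>` and `<pad>`, while GPT-2 has a single `<|endoftext|>` token standing in for BOS, EOS, UNK and PAD, and `SPECIAL_SYMBOLS` now carries the new entries. A small illustration, not from the repo, using only the names defined above:

```python
from lama.modules.base_connector import (
    GPT2_EOS,
    ROBERTA_MASK,
    ROBERTA_PAD,
    SPECIAL_SYMBOLS,
)

# GPT-2 has no dedicated mask or pad token, so the connector has to reuse
# <|endoftext|> for those roles.
assert GPT2_EOS == "<|endoftext|>"

# RoBERTa keeps distinct special tokens.
assert ROBERTA_MASK == "<mask>" and ROBERTA_PAD == "<pad>"

# The RoBERTa and GPT-2 symbols are now part of the shared special-symbol list.
assert GPT2_EOS in SPECIAL_SYMBOLS and ROBERTA_MASK in SPECIAL_SYMBOLS
```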
167 changes: 167 additions & 0 deletions lama/modules/gpt2_connector.py
@@ -0,0 +1,167 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from lama.modules.base_connector import *


class GPT2(Base_Connector):

    def __init__(self, args):
        super().__init__()

        if args.gpt2_model_dir is not None:
            # load GPT2 model from file
            gpt_model_name = str(args.gpt2_model_dir) + "/"
            dict_file = gpt_model_name
            print("loading GPT2 model from {}".format(gpt_model_name))
        else:
            # load GPT2 model from huggingface cache
            gpt_model_name = args.gpt2_model_name
            dict_file = gpt_model_name

        # Load pre-trained model tokenizer (vocabulary)
        self.tokenizer = GPT2Tokenizer.from_pretrained(dict_file)

        # GPT-2 uses byte-level BPE, which marks word boundaries differently
        # from BERT: a token that starts a new word (i.e. follows a space)
        # carries a leading 'Ġ', while in BERT the prefixes are written as is
        # and the parts that must follow (not be followed!) have a '##' prefix.
        # There is no one-to-one conversion, but at least we can make pieces
        # that may form a full word look the same.
        # Note that we should be very careful now:
        # tokenizer.convert_tokens_to_ids won't work with our vocabulary.

        def convert_word(word):
            if word == GPT2_EOS:
                return word

            if word.startswith('Ġ'):  # the token starts with a whitespace
                return word[1:]

            # The token does not start with a whitespace: it may not be the
            # head of a word, or it may be the head of a sentence.
            return f'_{word}_'

        _, gpt_vocab = zip(*sorted(self.tokenizer.decoder.items()))
        self.vocab = [convert_word(word) for word in gpt_vocab]
        self._init_inverse_vocab()

        # Load pre-trained model (weights)
        self.gpt_model = GPT2LMHeadModel.from_pretrained(gpt_model_name)
        self.gpt_model.eval()
        # print(self.gpt_model.config)

        # Sanity check.
        assert len(self.vocab) == self.gpt_model.config.vocab_size
        # assert 0 == self.gpt_model.config.n_special

        self.eos_id = self.gpt_model.config.eos_token_id
        self.pad_id = self.gpt_model.config.eos_token_id
        self.unk_id = self.gpt_model.config.eos_token_id
        self.bos_id = self.gpt_model.config.bos_token_id
        self.model_vocab = self.vocab

    def _cuda(self):
        self.gpt_model.cuda()

    def get_id(self, string):
        # Prepend a dummy word so that `string` is tokenized as it would be
        # in the middle of a sentence (with a leading space), then drop the
        # id of the dummy token.
        indexed_string = self.tokenizer.encode(f'a {string}')[1:]
        return indexed_string

    def __get_input_tensors(self, sentence_list):
        """Concatenates, tokenizes and converts a list of sentences to model inputs.

        Args:
            sentence_list: A list of strings. Each string may contain a special
                [MASK] token.

        Returns:
            A tuple (src_tensor, dst_tensor, masked_indices, tokenized_text).
            src_tensor: torch.LongTensor with shape (seq_len), the input to
                the model without the last symbol and with the BOS token prepended.
            dst_tensor: torch.LongTensor with shape (seq_len).
            masked_indices: A list of indices of [MASK] in dst_tensor.
            tokenized_text: The tokenized input, re-decoded into a string.
        """
        # Split the sentence by [MASK] and tokenize the chunks independently.
        tokenized_text = []
        masked_indices = []
        for sentence_idx, sentence in enumerate(sentence_list):
            if sentence_idx > 0:
                tokenized_text.append(self.eos_id)
            for chunk_idx, chunk in enumerate(sentence.split('[MASK]')):
                if chunk_idx > 0:
                    masked_indices.append(len(tokenized_text))
                    tokenized_text.append(self.unk_id)  # use UNK as [MASK]
                chunk = chunk.strip()
                if chunk:
                    tokenized_sentence = self.tokenizer.encode(chunk)
                    tokenized_text.extend(tokenized_sentence)

        full_indexed_tokens = [
            self.bos_id
        ] + tokenized_text
        full_tokens_tensor = torch.tensor(full_indexed_tokens)
        src_tensor = full_tokens_tensor[:-1]
        dst_tensor = full_tokens_tensor[1:]

        tokenized_text = self.tokenizer.decode(tokenized_text)

        return src_tensor, dst_tensor, masked_indices, tokenized_text

    def get_batch_generation(self, sentences_list, logger=None, try_cuda=True):
        if try_cuda:
            self.try_cuda()
        src_tensor_list, dst_tensor_list, masked_indices_list, _ = zip(*[
            self.__get_input_tensors(sentences) for sentences in sentences_list
        ])

        src_tensor_batch = torch.nn.utils.rnn.pad_sequence(
            src_tensor_list, batch_first=True)

        # The model uses a shared embedding space for tokens and positions.
        # More precisely, the first len(vocab) indices are reserved for words,
        # the last n_special symbols are reserved for special symbols and the
        # rest is used for positions. Softmax and embedding matrices are shared
        # and as a result some of the output "symbols" correspond to positions.
        # To fix that we have to manually remove the logits for positions.
        with torch.no_grad():
            logits = self.gpt_model(src_tensor_batch.to(self._model_device))[0]
            logits = logits[..., :self.gpt_model.config.vocab_size]

        log_probs = torch.nn.functional.log_softmax(logits, dim=-1).cpu()

        token_ids_list = [
            np.array(dst_tensor.numpy()) for dst_tensor in dst_tensor_list
        ]

        return log_probs, token_ids_list, masked_indices_list

    def get_contextual_embeddings(self, sentences_list, try_cuda=True):

        if try_cuda:
            self.try_cuda()

        src_tensor_list, dst_tensor_list, masked_indices_list, _ = zip(*[
            self.__get_input_tensors(sentences) for sentences in sentences_list
        ])

        src_tensor_batch = torch.nn.utils.rnn.pad_sequence(
            src_tensor_list, batch_first=True)

        with torch.no_grad():
            output = self.gpt_model.transformer(src_tensor_batch.to(self._model_device))

        # TODO
        sentence_lengths = None
        tokenized_text_list = None

        # As we only return the last layer, wrap it in [] to keep the same
        # format as the other models.
        return [output], sentence_lengths, tokenized_text_list
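End to end, the connector keeps the contract of the other LAMA modules: `get_batch_generation` returns per-position log-probabilities over the converted vocabulary together with the positions where `[MASK]` stood. A minimal usage sketch, not part of the PR; it assumes the weights resolve by name through the HuggingFace cache and that the loaded config exposes the `bos_token_id`/`eos_token_id` fields the constructor reads.

```python
from argparse import Namespace

import torch

from lama.modules.gpt2_connector import GPT2

args = Namespace(gpt2_model_dir=None, gpt2_model_name="gpt2")
model = GPT2(args)

# One example; [MASK] is swapped for the GPT-2 UNK/EOS id internally and its
# position comes back in masked_indices.
sentences_list = [["The cat is on the [MASK] ."]]
log_probs, token_ids, masked_indices = model.get_batch_generation(
    sentences_list, try_cuda=False)

mask_pos = masked_indices[0][0]
top5 = torch.topk(log_probs[0, mask_pos], k=5).indices.tolist()
print([model.vocab[i] for i in top5])  # top-5 candidates at the masked slot
```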

