[FEATURE] gpt2 generation scripts #1354
Merged

Changes from all commits (24 commits):
- `65f3e75` Merge pull request #1 from dmlc/master (hutao965)
- `bab8b65` remove prev_len in hybrid_forward parameters
- `b8268ae` update
- `fe5a93f` sample
- `c5d5ef3` update
- `59ce7cb` add gpt2_1558M
- `aae228c` update
- `e6c75fb` update
- `4ee0ea6` update
- `e300649` update
- `a272d8d` update
- `d7d5ac0` update
- `e2adfd2` update
- `00feb36` Merge pull request #2 from dmlc/master (hutao965)
- `7854a9b` update
- `260d74b` update
- `be64347` update
- `57797c2` update
- `a0ead28` update
- `42da589` update
- `887b3b6` update
- `0df7c70` update
- `df86748` update
- `d84f53f` update
@@ -0,0 +1,82 @@
Some of the examples below may include Unicode text characters. Set the environment variable:
```bash
export PYTHONIOENCODING=UTF-8
```
Use the following command to generate GPT-2 unconditional samples:
```bash
python3 generate_unconditional_gpt2_samples.py \
    --model_name gpt2_774M \
    --gpu 0 \
    --temperature 0.7 \
    --top_k 40 \
    --nsamples 1000 > samples
```
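Each generated sample in the output is preceded by a delimiter line, which `calculate_metrics.py` later uses to split the file. For example, the first sample starts with:
```
======================================== SAMPLE 1 ========================================
```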
Interactively generate GPT-2 conditional samples:
```bash
python3 interactive_conditional_gpt2_samples.py \
    --model_name gpt2_774M \
    --nsamples 1
```
Calculate some of the metrics from https://arxiv.org/pdf/1904.09751.pdf. These metrics are only heuristics, and there is no guarantee that they correlate well with human evaluation:
```bash
python3 calculate_metrics.py \
    --file samples
```
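The script prints the three metrics to stdout. For example, the pure-sampling configuration from the table below would print something like:
```
Self BLEU 4: 0.2701
Zipf coefficient: 0.9522
Repetition: 0.0
```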
Some metrics for the unconditionally generated text:

| GPT-2 774M    | Self-BLEU4 | Zipf Coefficient | Repetition % |
|---------------|------------|------------------|--------------|
| pure sampling | 0.2701     | 0.9522           | 0.0          |
| original gpt2 | 0.2750     | 0.9512           | 0.0          |
| t=0.9         | 0.3683     | 0.9619           | 0.1          |
| topk=40       | 0.4291     | 0.9666           | 0.0          |
| topk=640      | 0.3384     | 0.9623           | 0.0          |
| topk=40 t=0.7 | 0.4621     | 0.9586           | 1.1          |
Excerpts from some interesting unconditionally generated samples.
A story:
```
Looking back, Dil shook his head. The doll market was growing at an extraordinary rate; in his own opinion, it was unwarranted since his brother was sold to an abandoned bank. He was aware of what he had to do and was sure where his family was going; the thoughts worried him.

Although his brother had already grown an incredibly bulky gig paperback, he had never read a novel with an arguably more sinister turn. The intellectual gift of a child was reserved for reciting worked examples. As usual, exploiting loopholes, smart brother had practiced the art of overacting. Those tricks that remained medicinal classes grew weaker and smaller; in the end, one could not predict the fruition of those fighting skills.

Although he knew of a possible method of dealing with the right-winger, although he did not get his brother's hands on it, Regulus had already leaked his intentions in searching for Dil. He had already rushed passengers directly including that stupid bull. Due to the numerous setback, while Dil had luckily survived, he still suffered a decrease in his power.

He was reminded of the real reason why keeping secrets was not worth nothing; one must develop ones latent talents; in order to reverse one's stage of development all one had to do was give lessons to an opposite-type STUDENT that had similar abilities to those those that were bestowed by the parents; it was thus necessary to sift through the cat and mouse game over the years for those that had true deficiencies.
```
Code with comments:
```
struct Read <T> {

index: usize ,

size: usize ,

}

extern crate glob;

/// A function indexed by some unique index (The &uniqueID Thing will become the

/// @value@).

struct Parse <T: Read, U: Read, D: Read, Tore: Read> {

index: usize ,

index64: usize ,

}
```
scripts/generate/calculate_metrics.py
@@ -0,0 +1,109 @@
import os
import re
import argparse
import operator
import random
from collections import Counter
from functools import partial
from multiprocessing.pool import Pool

import numpy as np
from scipy import stats
from nltk.translate.bleu_score import sentence_bleu
from tqdm import tqdm

from gluonnlp.models.gpt2 import get_pretrained_gpt2

def parse_args():
    parser = argparse.ArgumentParser(
        description='Calculate metrics for the generated sentences')
    parser.add_argument('--file', type=str, required=True,
                        help='File containing the generated samples')
    parser.add_argument('--num_samples', type=int, default=1000,
                        help='Number of samples to evaluate')
    parser.add_argument('--num_bleu_samples', type=int, default=1000,
                        help='Number of samples drawn for the Self-BLEU4 estimate')
    return parser.parse_args()

def bleu(sample_strs, i):
    # BLEU4 of sample i against all other samples, with uniform 1- to 4-gram weights
    return sentence_bleu(
        hypothesis=sample_strs[i],
        references=sample_strs[:i] + sample_strs[i+1:],
        weights=(0.25, 0.25, 0.25, 0.25)
    )

def calculate_self_bleu4(sample_strs, num_bleu_samples):
    """Self-BLEU is calculated by computing the BLEU score of each generated document
    using all other generations in the evaluation set as references.
    """
    pool = Pool(processes=os.cpu_count())
    bleu_scores = list(tqdm(
        pool.imap_unordered(
            partial(bleu, sample_strs),
            random.sample(range(len(sample_strs)), num_bleu_samples)),
        total=num_bleu_samples
    ))
    return sum(bleu_scores) / num_bleu_samples

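# Note on calculate_self_bleu4 (hypothetical illustration): with tokenized
# samples such as
#     sample_strs = [['the', 'cat', 'sat'], ['the', 'cat', 'ran'], ['a', 'dog', 'barked']]
# the Self-BLEU4 of sample 0 is its BLEU score against samples 1 and 2 as
# references. Lower Self-BLEU indicates more diverse generations.
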
def calculate_zipf_coefficient(sample_ids, tokenizer):
    """Measure how closely the token frequency distribution of the text follows
    Zipf's law (a power law), reported as the R-squared of a linear fit of the
    sorted frequencies in log-log space.
    """
    cnt = Counter()
    for sample_id in sample_ids:
        cnt.update(sample_id)

    xs = np.arange(1, min(len(cnt), len(tokenizer.vocab)) + 1)
    ys = np.array(sorted(cnt.values(), key=operator.neg)[:len(tokenizer.vocab)])
    _, _, r, _, _ = stats.linregress(np.log(xs), np.log(ys))
    return r ** 2

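# Note on interpreting the value returned above: an R-squared near 1 means the
# sorted token frequencies lie on a straight line in log-log space, i.e. the
# text follows Zipf's law about as closely as natural language typically does.
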
def calculate_repetition(sample_ids):
    """The fraction of samples that degenerate into repetition at the end.

    A sample counts as repeated if its final tokens consist of a phrase of at
    least 3 tokens that occurs at least twice in a row, or a shorter phrase
    that occurs more than 50 times in a row.
    """
    max_n = 90
    n_repeated_examples = 0
    for sample_id in sample_ids:
        rev = list(reversed(sample_id))
        last_n_repeats = [0 for _ in range(max_n)]
        # For each phrase length n, count how many times the last n tokens of
        # the sample repeat back-to-back at its end.
        for n in range(1, max_n + 1):
            n_repeat = 1
            while len(rev[n*n_repeat:n*(n_repeat+1)]) == n and \
                    rev[n*n_repeat:n*(n_repeat+1)] == rev[:n]:
                n_repeat += 1
            last_n_repeats[n-1] = n_repeat
        max_repeated_n = max(range(max_n), key=lambda x: last_n_repeats[x])
        if last_n_repeats[max_repeated_n] > 1 and \
                (max_repeated_n+1 >= 3 or last_n_repeats[max_repeated_n] > 50):
            n_repeated_examples += 1
    return n_repeated_examples / len(sample_ids)

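# Hypothetical examples for calculate_repetition: a sample ending in
# "... the the the the" repeats a 1-token phrase only 4 times, so it is not
# counted; a sample ending in "... I am here I am here I am here" repeats a
# 3-token phrase at least twice, so it is counted.
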
def calculate_metrics(args):
    with open(args.file, encoding='utf-8') as of:
        samples = of.read()
    # Samples in the file are separated by '==== ... SAMPLE <i> ==== ...' lines
    pattern = '='*40 + r' SAMPLE \d+ ' + '='*40 + '\n'
    samples = re.split(pattern, samples)[1:]
    samples = samples[:args.num_samples]
    assert len(samples) == args.num_samples

    _, tokenizer, _, _ = get_pretrained_gpt2(
        load_backbone=False,
        load_lm=False)
    sample_ids = tokenizer.encode(samples, output_type=int)
    # Drop the trailing EOS token, if present, from each encoded sample
    sample_ids = [ids[:-1] if ids and ids[-1] == tokenizer.vocab.eos_id else ids
                  for ids in sample_ids]
    sample_strs = tokenizer.encode(samples, output_type=str)

    self_bleu4 = calculate_self_bleu4(sample_strs, args.num_bleu_samples)
    zipf_coefficient = calculate_zipf_coefficient(sample_ids, tokenizer)
    repetition = calculate_repetition(sample_ids)
    print('Self BLEU 4: {}\n'
          'Zipf coefficient: {}\n'
          'Repetition: {}\n'
          .format(self_bleu4, zipf_coefficient, repetition))

if __name__ == '__main__':
    args = parse_args()
    calculate_metrics(args)
scripts/generate/generate_unconditional_gpt2_samples.py (105 additions, 0 deletions)
@@ -0,0 +1,105 @@
import os
import argparse

import mxnet as mx
from gluonnlp.utils import set_seed
from gluonnlp.sequence_sampler import BeamSearchSampler, BaseStepDecoder
from gluonnlp.models.gpt2 import GPT2ForLM, list_pretrained_gpt2, get_pretrained_gpt2

mx.npx.set_np()

def parse_args():
    parser = argparse.ArgumentParser(
        description='GPT-2 unconditional sampler. Load a GPT-2 model and sample.')
    parser.add_argument('--model_name', type=str, default='gpt2_124M',
                        choices=list_pretrained_gpt2(), help='Model name')
    parser.add_argument('--seed', type=int, default=None, help='The random seed')
    parser.add_argument('--nsamples', type=int, default=0,
                        help='Number of samples to generate; '
                             'a value <= 0 samples indefinitely')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Number of samples generated per batch')
    parser.add_argument('--length', type=int, default=None,
                        help='Number of tokens in generated text, if None (default), is '
                             'determined by model max_length')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='Softmax temperature; values below 1.0 make the '
                             'distribution sharper and the text less random')
    parser.add_argument('--top_k', type=int, default=-1,
                        help='Multinomial sampling with top-k, '
                             'see [ACL2018] "Hierarchical Neural Story Generation" '
                             'https://www.aclweb.org/anthology/P18-1082.pdf')
    parser.add_argument('--top_p', type=float, default=-1.0,
                        help='Multinomial sampling with top-p, '
                             'see [ICLR2020] "The Curious Case of Neural Text Degeneration" '
                             'https://arxiv.org/abs/1904.09751')
    parser.add_argument('--gpu', type=int, default=0,
                        help='Which gpu to use; set None to use cpu')
    return parser.parse_args()

class GPT2Decoder(BaseStepDecoder):
    """One-step-ahead decoder wrapping GPT2ForLM for use with the sequence sampler."""

    def __init__(self, gpt2_lm_model):
        self._gpt2_lm_model = gpt2_lm_model

    @property
    def state_batch_axis(self):
        return 2 if self._gpt2_lm_model._backbone_model.layout == 'NT' else 3

    def init_states(self, batch_size, ctx):
        return self._gpt2_lm_model.init_states(batch_size, ctx)

    def __call__(self, data, states):
        data = mx.npx.reshape(data, (-1, 1))
        logits, new_states = self._gpt2_lm_model(data, states)
        # Keep only the logits of the last position
        return logits[:, -1, :], new_states

def sample_gpt2(args):
    ctx = mx.gpu(args.gpu) if args.gpu is not None else mx.cpu()

    cfg, tokenizer, _, lm_params_path = get_pretrained_gpt2(
        model_name=args.model_name,
        load_backbone=False,
        load_lm=True)

    if args.length is None:
        args.length = cfg.MODEL.max_length
    assert args.length <= cfg.MODEL.max_length, \
        "Can't get samples longer than window size: {}".format(cfg.MODEL.max_length)

    model = GPT2ForLM(cfg)
    model.hybridize()
    model.load_parameters(lm_params_path, ctx=ctx)
    gpt2decoder = GPT2Decoder(model)

    # With beam_size=1 and sampling=True, the BeamSearchSampler performs plain
    # multinomial sampling, optionally truncated by top-k / top-p and rescaled
    # by the temperature
    sampler = BeamSearchSampler(
        beam_size=1,
        decoder=gpt2decoder,
        eos_id=None,
        vocab_size=cfg.MODEL.vocab_size,
        max_length_a=0,
        max_length_b=args.length,
        min_length=1,
        temperature=args.temperature,
        sampling=True,
        sampling_topp=args.top_p,
        sampling_topk=args.top_k,
        early_return=False
    )

    # Unconditional samples start from the EOS token, as in the original GPT-2 setup
    start_input = mx.np.full((args.batch_size, 1), tokenizer.vocab.eos_id, ctx=ctx)
    start_states = gpt2decoder.init_states(args.batch_size, ctx)

    generated = 0
    while args.nsamples <= 0 or generated < args.nsamples:
        samples, _, _ = sampler(start_input, start_states)
        for i in range(args.batch_size):
            generated += 1
            ids = samples[i][0].asnumpy().tolist()
            # Strip the leading EOS token and anything after the first padding id (-1)
            ids = ids[1:ids.index(-1)] if -1 in ids else ids[1:]
            text = tokenizer.decode(ids)
            print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
            print(text)

if __name__ == '__main__':
    os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
    args = parse_args()
    if args.seed is not None:
        set_seed(args.seed)
    sample_gpt2(args)
Review comments:

- Add newline
- The correct coding style should be