diff --git a/scripts/conversion_toolkits/convert_gpt2.py b/scripts/conversion_toolkits/convert_gpt2.py
index 7efe720922..fc23ed9809 100644
--- a/scripts/conversion_toolkits/convert_gpt2.py
+++ b/scripts/conversion_toolkits/convert_gpt2.py
@@ -170,8 +170,8 @@ def test_model(tf_model_path, gluon_model):
 
     # gluon model
     gl_input_ids = mx.np.array(input_ids, dtype=np.int32, ctx=ctx)
-    gl_logits_1, gl_states = gluon_model(gl_input_ids, gl_start_states, mx.np.array(0, dtype=np.int32, ctx=ctx))
-    gl_logits_2, _ = gluon_model(gl_input_ids, gl_states, mx.np.array(seq_length, dtype=np.int32, ctx=ctx))
+    gl_logits_1, gl_states = gluon_model(gl_input_ids, gl_start_states)
+    gl_logits_2, _ = gluon_model(gl_input_ids, gl_states)
 
     # tf model
     with tf.Session(graph=tf.Graph()) as sess:
diff --git a/scripts/conversion_toolkits/convert_gpt2.sh b/scripts/conversion_toolkits/convert_gpt2.sh
index a551250c4b..febc2ec2f8 100644
--- a/scripts/conversion_toolkits/convert_gpt2.sh
+++ b/scripts/conversion_toolkits/convert_gpt2.sh
@@ -1,6 +1,6 @@
 python3 -m pip install tensorflow==1.15 --upgrade --user
 git clone https://github.com/openai/gpt-2.git gpt_2
-for model in 124M 355M 774M
+for model in 124M 355M 774M 1558M
do
     python3 gpt_2/download_model.py ${model}
     mkdir gpt2_${model}
diff --git a/scripts/generate/README.md b/scripts/generate/README.md
new file mode 100644
index 0000000000..b009891d3d
--- /dev/null
+++ b/scripts/generate/README.md
@@ -0,0 +1,82 @@
+Some of the examples below may include Unicode text characters. Set the environment variable:
+```bash
+export PYTHONIOENCODING=UTF-8
+```
+
+Use the following command to generate GPT-2 unconditional samples:
+```bash
+python3 generate_unconditional_gpt2_samples.py \
+    --model_name gpt2_774M \
+    --gpu 0 \
+    --temperature 0.7 \
+    --top_k 40 \
+    --nsamples 1000 > samples
+```
+
+
+Interactively generate GPT-2 conditional samples:
+```bash
+python3 interactive_conditional_gpt2_samples.py \
+    --model_name gpt2_774M \
+    --nsamples 1
+```
+
+Calculate some of the metrics from https://arxiv.org/pdf/1904.09751.pdf.
+These metrics are just heuristics and there is no guarantee that they correlate well with human evaluation.
+```bash
+python3 calculate_metrics.py \
+    --file samples
+```
+
+
+Some metrics for the unconditionally generated text:
+
+| GPT2 774M     | Self-BLEU4 | Zipf Coefficient | Repetition % |
+|---------------|------------|------------------|--------------|
+| pure sampling | 0.2701     | 0.9522           | 0.0          |
+| original gpt2 | 0.2750     | 0.9512           | 0.0          |
+| t=0.9         | 0.3683     | 0.9619           | 0.1          |
+| topk=40       | 0.4291     | 0.9666           | 0.0          |
+| topk=640      | 0.3384     | 0.9623           | 0.0          |
+| topk=40 t=0.7 | 0.4621     | 0.9586           | 1.1          |
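Self-BLEU4 in the table above scores each generated document with BLEU-4 against all the other generations as references, so higher values mean the samples are more alike (less diverse). A toy sketch of that computation follows; the three short documents are made up for illustration, while the real evaluation in `calculate_metrics.py` below uses 1000 full-length samples, where 4-gram overlap is meaningful:

```python
# Toy Self-BLEU4 sketch over a hypothetical 3-document corpus.
from nltk.translate.bleu_score import sentence_bleu

docs = [s.split() for s in [
    'the cat sat on the mat near the door',
    'the cat sat on the rug near the door',
    'a small dog ran around the park today',
]]
# Each document is the hypothesis; every other document is a reference.
scores = [
    sentence_bleu(docs[:i] + docs[i + 1:], docs[i],
                  weights=(0.25, 0.25, 0.25, 0.25))
    for i in range(len(docs))
]
print(sum(scores) / len(scores))  # higher => generations are more alike
```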
+
+
+Excerpts from some interesting unconditionally generated samples
+
+
+A story
+```
+Looking back, Dil shook his head. The doll market was growing at an extraordinary rate; in his own opinion, it was unwarranted since his brother was sold to an abandoned bank. He was aware of what he had to do and was sure where his family was going; the thoughts worried him.
+
+Although his brother had already grown an incredibly bulky gig paperback, he had never read a novel with an arguably more sinister turn. The intellectual gift of a child was reserved for reciting worked examples. As usual, exploiting loopholes, smart brother had practiced the art of overacting. Those tricks that remained medicinal classes grew weaker and smaller; in the end, one could not predict the fruition of those fighting skills.
+
+Although he knew of a possible method of dealing with the right-winger, although he did not get his brother's hands on it, Regulus had already leaked his intentions in searching for Dil. He had already rushed passengers directly including that stupid bull. Due to the numerous setback, while Dil had luckily survived, he still suffered a decrease in his power.
+
+He was reminded of the real reason why keeping secrets was not worth nothing; one must develop ones latent talents; in order to reverse one's stage of development all one had to do was give lessons to an opposite-type STUDENT that had similar abilities to those those that were bestowed by the parents; it was thus necessary to sift through the cat and mouse game over the years for those that had true deficiencies.
+```
+
+Code with comments
+```
+struct Read {
+
+index: usize ,
+
+size: usize ,
+
+}
+
+extern crate glob;
+
+/// A function indexed by some unique index (The &uniqueID Thing will become the
+
+/// @value@).
+
+struct Parse {
+
+index: usize ,
+
+index64: usize ,
+
+}
+
+```
diff --git a/scripts/generate/calculate_metrics.py b/scripts/generate/calculate_metrics.py
new file mode 100644
index 0000000000..72b7b74c36
--- /dev/null
+++ b/scripts/generate/calculate_metrics.py
@@ -0,0 +1,108 @@
+import os
+import re
+import argparse
+from nltk.translate.bleu_score import sentence_bleu
+from collections import Counter
+import operator
+import numpy as np
+from scipy import stats
+import random
+from tqdm import tqdm
+from functools import partial
+from multiprocessing.pool import Pool
+from gluonnlp.models.gpt2 import get_pretrained_gpt2
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Calculate metrics for the generated sentences')
+    parser.add_argument('--file', type=str, required=True, help='Path of the generated samples file')
+    parser.add_argument('--num_samples', type=int, default=1000, help='Number of samples to evaluate')
+    parser.add_argument('--num_bleu_samples', type=int, default=1000, help='Number of samples used for Self-BLEU4')
+    return parser.parse_args()
+
+
+def bleu(sample_strs, i):
+    return sentence_bleu(
+        hypothesis=sample_strs[i],
+        references=sample_strs[:i] + sample_strs[i+1:],
+        weights=(0.25, 0.25, 0.25, 0.25)
+    )
+
+
+def calculate_self_bleu4(sample_strs, num_bleu_samples):
+    """Self-BLEU is calculated by computing the BLEU score of each generated document
+    using all other generations in the evaluation set as references.
+    """
+    pool = Pool(processes=os.cpu_count())
+    bleu_scores = list(tqdm(
+        pool.imap_unordered(
+            partial(bleu, sample_strs),
+            random.sample(range(len(sample_strs)), num_bleu_samples)),
+        total=num_bleu_samples
+    ))
+    return sum(bleu_scores) / num_bleu_samples
+
+
+def calculate_zipf_coefficient(sample_ids, tokenizer):
+    """The Zipf coefficient (the R-squared of a log-log linear fit of token frequency
+    against rank) measures how closely a given text follows a perfect Zipfian (power-law) curve.
+    """
+    cnt = Counter()
+    for sample_id in sample_ids:
+        cnt.update(sample_id)
+
+    xs = np.arange(1, min(len(cnt), len(tokenizer.vocab)) + 1)
+    ys = np.array(sorted(cnt.values(), key=operator.neg)[:len(tokenizer.vocab)])
+    _, _, r, _, _ = stats.linregress(np.log(xs), np.log(ys))
+    return r ** 2
+
+
+def calculate_repetition(sample_ids):
+    """The fraction of generated samples whose ending degenerates into a repetition loop.
+    """
+    max_n = 90
+    n_repeated_examples = 0
+    for sample_id in sample_ids:
+        rev = list(reversed(sample_id))
+        last_n_repeats = [0 for _ in range(max_n)]
+        for n in range(1, max_n + 1):
+            n_repeat = 1
+            while len(rev[n*n_repeat:n*(n_repeat+1)]) == n and \
+                    rev[n*n_repeat:n*(n_repeat+1)] == rev[:n]:
+                n_repeat += 1
+            last_n_repeats[n-1] = n_repeat
+        max_repeated_n = max(range(max_n), key=lambda x: last_n_repeats[x])
+        if last_n_repeats[max_repeated_n] > 1 and (max_repeated_n+1 >= 3 or last_n_repeats[max_repeated_n] > 50):
+            n_repeated_examples += 1
+    return n_repeated_examples / len(sample_ids)
+
+
+def calculate_metrics(args):
+    with open(args.file, encoding='utf-8') as of:
+        samples = of.read()
+    pattern = '='*40 + r' SAMPLE \d+ ' + '='*40 + '\n'
+    samples = re.split(pattern, samples)[1:]
+    samples = samples[:args.num_samples]
+    assert len(samples) == args.num_samples
+
+    _, tokenizer, _, _ = get_pretrained_gpt2(
+        load_backbone=False,
+        load_lm=False)
+    sample_ids = tokenizer.encode(samples, output_type=int)
+    if sample_ids[-1] == tokenizer.vocab.eos_id:
+        sample_ids.pop()
+    sample_strs = tokenizer.encode(samples, output_type=str)
+
+    self_bleu4 = calculate_self_bleu4(sample_strs, args.num_bleu_samples)
+    zipf_coefficient = calculate_zipf_coefficient(sample_ids, tokenizer)
+    repetition = calculate_repetition(sample_ids)
+    print('Self BLEU 4: {}\n'
+          'Zipf coefficient: {}\n'
+          'Repetition: {}\n'
+          .format(self_bleu4, zipf_coefficient, repetition))
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    calculate_metrics(args)
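The Zipf coefficient reported by `calculate_zipf_coefficient` above is the R-squared of a straight-line fit to log frequency versus log rank. A toy illustration with made-up token counts (real runs use the token statistics of the generated samples):

```python
# Toy Zipf-coefficient sketch (hypothetical frequency data).
import numpy as np
from scipy import stats

freqs = np.array([1000, 480, 320, 250, 190, 160, 140, 120])  # sorted counts
ranks = np.arange(1, len(freqs) + 1)
_, _, r, _, _ = stats.linregress(np.log(ranks), np.log(freqs))
print(r ** 2)  # ~1.0 when the rank-frequency plot is a straight line in log-log space
```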
diff --git a/scripts/generate/generate_unconditional_gpt2_samples.py b/scripts/generate/generate_unconditional_gpt2_samples.py
new file mode 100644
index 0000000000..aa7cf3ecab
--- /dev/null
+++ b/scripts/generate/generate_unconditional_gpt2_samples.py
@@ -0,0 +1,105 @@
+import os
+import mxnet as mx
+import argparse
+from gluonnlp.utils import set_seed
+from gluonnlp.sequence_sampler import BeamSearchSampler, BaseStepDecoder
+from gluonnlp.models.gpt2 import GPT2ForLM, list_pretrained_gpt2, get_pretrained_gpt2
+
+mx.npx.set_np()
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='GPT-2 unconditional sampler. Load a GPT-2 model and sample.')
+    parser.add_argument('--model_name', type=str, default='gpt2_124M',
+                        choices=list_pretrained_gpt2(), help='Model name')
+    parser.add_argument('--seed', type=int, default=None, help='The random seed')
+    parser.add_argument('--nsamples', type=int, default=0, help='Number of samples to return; values <= 0 sample indefinitely')
+    parser.add_argument('--batch_size', type=int, default=1, help='Number of samples generated per batch')
+    parser.add_argument('--length', type=int, default=None,
+                        help='Number of tokens in generated text, if None (default), is '
+                             'determined by model max_length')
+    parser.add_argument('--temperature', type=float, default=1.0,
+                        help='Softmax temperature for sampling')
+    parser.add_argument('--top_k', type=int, default=-1,
+                        help='Multinomial sampling with topk, '
+                             'see [ACL2018] "Hierarchical Neural Story Generation" '
+                             'https://www.aclweb.org/anthology/P18-1082.pdf')
+    parser.add_argument('--top_p', type=float, default=-1.0,
+                        help='Multinomial sampling with topp, '
+                             'see [ICLR2020] "The Curious Case of Neural Text Degeneration" '
+                             'https://arxiv.org/abs/1904.09751')
+    parser.add_argument('--gpu', type=int, default=0,
+                        help='Which gpu to use, set None to use cpu')
+    return parser.parse_args()
+
+
+class GPT2Decoder(BaseStepDecoder):
+    def __init__(self, gpt2_lm_model):
+        self._gpt2_lm_model = gpt2_lm_model
+    @property
+    def state_batch_axis(self):
+        return 2 if self._gpt2_lm_model._backbone_model.layout == 'NT' else 3
+    def init_states(self, batch_size, ctx):
+        return self._gpt2_lm_model.init_states(batch_size, ctx)
+    def __call__(self, data, states):
+        data = mx.npx.reshape(data, (-1, 1))
+        logits, new_states = self._gpt2_lm_model(data, states)
+        return logits[:, -1, :], new_states
+
+
+def sample_gpt2(args):
+    ctx = mx.gpu(args.gpu) if args.gpu is not None else \
+        mx.cpu()
+
+    cfg, tokenizer, _, lm_params_path = get_pretrained_gpt2(
+        model_name=args.model_name,
+        load_backbone=False,
+        load_lm=True)
+
+    if args.length is None:
+        args.length = cfg.MODEL.max_length
+    assert args.length <= cfg.MODEL.max_length, \
+        "Can't get samples longer than window size: {}".format(cfg.MODEL.max_length)
+
+    model = GPT2ForLM(cfg)
+    model.hybridize()
+    model.load_parameters(lm_params_path, ctx=ctx)
+    gpt2decoder = GPT2Decoder(model)
+
+    sampler = BeamSearchSampler(
+        beam_size=1,
+        decoder=gpt2decoder,
+        eos_id=None,
+        vocab_size=cfg.MODEL.vocab_size,
+        max_length_a=0,
+        max_length_b=args.length,
+        min_length=1,
+        temperature=args.temperature,
+        sampling=True,
+        sampling_topp=args.top_p,
+        sampling_topk=args.top_k,
+        early_return=False
+    )
+
+    start_input = mx.np.full((args.batch_size, 1), tokenizer.vocab.eos_id, ctx=ctx)
+    start_states = gpt2decoder.init_states(args.batch_size, ctx)
+
+    generated = 0
+    while args.nsamples <= 0 or generated < args.nsamples:
+        samples, _, _ = sampler(start_input, start_states)
+        for i in range(args.batch_size):
+            generated += 1
+            ids = samples[i][0].asnumpy().tolist()
+            ids = ids[1:ids.index(-1)] if -1 in ids else \
+                ids[1:]
+            text = tokenizer.decode(ids)
+            print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
+            print(text)
+
+
+if __name__ == '__main__':
+    os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
+    args = parse_args()
+    if args.seed is not None:
+        set_seed(args.seed)
+    sample_gpt2(args)
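A note on `state_batch_axis` in `GPT2Decoder` above: per the docstrings in `gpt2.py` below, the stacked incremental states have shape `(num_layers, 2, batch_size, prev_len, C_in)` for layout `'NT'` and `(num_layers, 2, prev_len, batch_size, C_in)` for `'TN'`, so the batch axis the sampler needs to tile and gather over is 2 or 3 respectively. A small illustration with hypothetical sizes:

```python
import mxnet as mx
mx.npx.set_np()

num_layers, batch_size, prev_len, units = 12, 4, 1, 768
# layout 'NT': (num_layers, 2, batch_size, prev_len, C_in) -> batch axis 2
states_nt = mx.np.zeros((num_layers, 2, batch_size, prev_len, units))
assert states_nt.shape[2] == batch_size
# layout 'TN': (num_layers, 2, prev_len, batch_size, C_in) -> batch axis 3
states_tn = mx.np.zeros((num_layers, 2, prev_len, batch_size, units))
assert states_tn.shape[3] == batch_size
```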
diff --git a/scripts/generate/interactive_conditional_gpt2_samples.py b/scripts/generate/interactive_conditional_gpt2_samples.py
new file mode 100644
index 0000000000..8d3c7c85e0
--- /dev/null
+++ b/scripts/generate/interactive_conditional_gpt2_samples.py
@@ -0,0 +1,112 @@
+import os
+import mxnet as mx
+import argparse
+from gluonnlp.utils import set_seed
+from gluonnlp.sequence_sampler import BeamSearchSampler, BaseStepDecoder
+from gluonnlp.models.gpt2 import GPT2ForLM, list_pretrained_gpt2, get_pretrained_gpt2
+
+mx.npx.set_np()
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='GPT-2 conditional sampler. Load a GPT-2 model and sample from a prompt.')
+    parser.add_argument('--model_name', type=str, default='gpt2_124M',
+                        choices=list_pretrained_gpt2(), help='Model name')
+    parser.add_argument('--seed', type=int, default=None, help='The random seed')
+    parser.add_argument('--nsamples', type=int, default=0, help='Number of samples to generate per prompt')
+    parser.add_argument('--batch_size', type=int, default=1, help='Number of samples generated per batch')
+    parser.add_argument('--length', type=int, default=None,
+                        help='Number of tokens in generated text, if None (default), is '
+                             'determined by model max_length')
+    parser.add_argument('--temperature', type=float, default=1.0,
+                        help='Softmax temperature for sampling')
+    parser.add_argument('--top_k', type=int, default=-1,
+                        help='Multinomial sampling with topk, '
+                             'see [ACL2018] "Hierarchical Neural Story Generation" '
+                             'https://www.aclweb.org/anthology/P18-1082.pdf')
+    parser.add_argument('--top_p', type=float, default=-1.0,
+                        help='Multinomial sampling with topp, '
+                             'see [ICLR2020] "The Curious Case of Neural Text Degeneration" '
+                             'https://arxiv.org/abs/1904.09751')
+    parser.add_argument('--gpu', type=int, default=0,
+                        help='Which gpu to use, set None to use cpu')
+    return parser.parse_args()
+
+
+class GPT2Decoder(BaseStepDecoder):
+    def __init__(self, gpt2_lm_model):
+        self._gpt2_lm_model = gpt2_lm_model
+    @property
+    def state_batch_axis(self):
+        return 2 if self._gpt2_lm_model._backbone_model.layout == 'NT' else 3
+    def init_states(self, batch_size, ctx):
+        return self._gpt2_lm_model.init_states(batch_size, ctx)
+    def __call__(self, data, states):
+        data = mx.npx.reshape(data, (-2, -1))
+        logits, new_states = self._gpt2_lm_model(data, states)
+        return logits[:, -1, :], new_states
+
+
+def sample_gpt2(args):
+    ctx = mx.gpu(args.gpu) if args.gpu is not None else \
+        mx.cpu()
+
+    cfg, tokenizer, _, lm_params_path = get_pretrained_gpt2(
+        model_name=args.model_name,
+        load_backbone=False,
+        load_lm=True)
+
+    if args.length is None:
+        args.length = cfg.MODEL.max_length
+    assert args.length <= cfg.MODEL.max_length, \
+        "Can't get samples longer than window size: {}".format(cfg.MODEL.max_length)
+
+    model = GPT2ForLM(cfg)
+    model.hybridize()
+    model.load_parameters(lm_params_path, ctx=ctx)
+    gpt2decoder = GPT2Decoder(model)
+
+    sampler = BeamSearchSampler(
+        beam_size=1,
+        decoder=gpt2decoder,
+        eos_id=None,
+        vocab_size=cfg.MODEL.vocab_size,
+        max_length_a=0,
+        max_length_b=args.length,
+        min_length=1,
+        temperature=args.temperature,
+        sampling=True,
+        sampling_topp=args.top_p,
+        sampling_topk=args.top_k,
+        early_return=False
+    )
+    start_states = gpt2decoder.init_states(args.batch_size, ctx)
+
+    while True:
+        raw_text = input('Model prompt >>> ')
+        while not raw_text:
+            print('Prompt should not be empty!')
+            raw_text = input('Model prompt >>> ')
+        context_tokens = tokenizer.encode(raw_text, output_type=int)
+        start_input = mx.np.repeat(mx.np.expand_dims(mx.np.array(context_tokens, ctx=ctx), 0),
+                                   args.batch_size,
+                                   axis=0)
+        generated = 0
+        while generated < args.nsamples:
+            samples, _, _ = sampler(start_input, start_states)
+            for i in range(args.batch_size):
+                generated += 1
+                ids = samples[i][0].asnumpy().tolist()
+                ids = ids[1:ids.index(-1)] if -1 in ids else \
+                    ids[1:]
+                text = tokenizer.decode(ids)
+                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
+                print(text)
+        print("=" * 80)
+
+if __name__ == '__main__':
+    os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
+    args = parse_args()
+    if args.seed is not None:
+        set_seed(args.seed)
+    sample_gpt2(args)
diff --git a/src/gluonnlp/models/gpt2.py b/src/gluonnlp/models/gpt2.py
index f24bad0317..206bd7b372 100644
--- a/src/gluonnlp/models/gpt2.py
+++ b/src/gluonnlp/models/gpt2.py
@@ -96,6 +96,16 @@ def gpt2_774M():
     cfg.freeze()
     return cfg
 
+@gpt2_cfg_reg.register()
+def gpt2_1558M():
+    cfg = gpt2_124M()
+    cfg.defrost()
+    cfg.MODEL.num_heads = 25
+    cfg.MODEL.num_layers = 48
+    cfg.MODEL.units = 1600
+    cfg.freeze()
+    return cfg
+
 PRETRAINED_URL = {
     'gpt2_124M': {
         'cfg': gpt2_124M(),
@@ -118,6 +128,13 @@ def gpt2_774M():
         'params': 'gpt2_774M/model-9917e24e.params',
         'lm_params': 'gpt2_774M/model_lm-cfbfa641.params'
     },
+    'gpt2_1558M': {
+        'cfg': gpt2_1558M(),
+        'merges': 'gpt2_1558M/gpt2-396d4d8e.merges',
+        'vocab': 'gpt2_1558M/gpt2-9dc62091.vocab',
+        'params': 'gpt2_1558M/model-af3dd713.params',
+        'lm_params': 'gpt2_1558M/model_lm-c8489dcb.params'
+    },
 }
 
 
@@ -180,7 +197,7 @@ def __init__(self, units: int = 768,
         )
         self.hidden_dropout = nn.Dropout(self._hidden_dropout_prob)
 
-    def hybrid_forward(self, F, x, layer_states, prev_len):
+    def hybrid_forward(self, F, x, layer_states):
         """
 
         Parameters
@@ -195,13 +212,14 @@ def hybrid_forward(self, F, x, layer_states, prev_len):
             Shape (2, batch_size, prev_len, C_in)
             - layout = 'TN'
             Shape (2, prev_len, batch_size, C_in)
-        prev_len
         """
         x = self.ln(x)
         if self._layout == 'NT':
             batch_axis, time_axis = 0, 1
+            prev_len = F.npx.shape_array(layer_states)[2]
         else:
             batch_axis, time_axis = 1, 0
+            prev_len = F.npx.shape_array(layer_states)[1]
 
         query, key, value = F.np.split(self.qkv(x), 3, axis=-1)
         if layer_states is not None:
@@ -333,7 +351,7 @@ def __init__(self, units: int = 768,
             dtype=self._dtype
         )
 
-    def hybrid_forward(self, F, x, layer_states, prev_len):
+    def hybrid_forward(self, F, x, layer_states):
         """
 
         Parameters
@@ -349,8 +367,6 @@ def hybrid_forward(self, F, x, layer_states, prev_len):
             Shape (2, batch_size, prev_len, C_in)
             - layout = 'TN'
             Shape (2, prev_len, batch_size, C_in)
-        prev_len
-            The previous length
 
         Returns
         -------
@@ -366,7 +382,7 @@ def hybrid_forward(self, F, x, layer_states, prev_len):
             - layout = 'TN'
             Shape (2, prev_len + seq_length, batch_size, C_in)
         """
-        h, new_layer_states = self.atten(x, layer_states, prev_len)
+        h, new_layer_states = self.atten(x, layer_states)
         x = x + h
         h = self.ffn(x)
         return h, new_layer_states
@@ -451,7 +467,7 @@ def __init__(self,
     def layout(self):
         return self._layout
 
-    def hybrid_forward(self, F, x, states, prev_len):
+    def hybrid_forward(self, F, x, states):
         """
 
         Parameters
@@ -468,8 +484,6 @@ def hybrid_forward(self, F, x, states, prev_len):
             Shape (num_layers, 2, batch_size, prev_len, C_in)]
             - layout = 'TN'
             Shape (num_layers, 2, prev_len, batch_size, C_in)]
-        prev_len
-            The previous length. It will be a scalar.
 
         Returns
         -------
@@ -486,6 +500,8 @@ def hybrid_forward(self, F, x, states, prev_len):
             - layout = 'TN'
             Shape (num_layers, 2, prev_len + seq_length, batch_size, C_in)
         """
+        prev_len = F.npx.shape_array(states)[3] if self._layout == 'NT' else \
+            F.npx.shape_array(states)[2]
         x = self.get_initial_embedding(F, x, prev_len)
 
         if self._layout != self._compute_layout:
@@ -495,7 +511,7 @@ def hybrid_forward(self, F, x, states, prev_len):
         new_states = []
         for layer_idx in range(self._num_layers):
             layer_states = None if states is None else states[layer_idx]
-            x, new_layer_states = self._layers[layer_idx](x, layer_states, prev_len)
+            x, new_layer_states = self._layers[layer_idx](x, layer_states)
             new_states.append(new_layer_states)
         new_states = F.np.stack(new_states, axis=0)
 
@@ -609,7 +625,7 @@ def __init__(self, backbone_cfg=None):
         )
         self._lm_head.weight = self._backbone_model._embed.weight
 
-    def hybrid_forward(self, F, inputs, states, prev_len):
+    def hybrid_forward(self, F, inputs, states):
         """Getting the logits
 
         Parameters
@@ -626,8 +642,6 @@ def hybrid_forward(self, F, inputs, states, prev_len):
             Shape (num_layers, 2, batch_size, prev_len, C_in)
             - layout = 'TN'
             Shape (num_layers, 2, prev_len, batch_size, C_in)
-        prev_len
-            Will be a scalar that represents the previous length
 
         Returns
         -------
@@ -642,7 +656,7 @@ def hybrid_forward(self, F, inputs, states, prev_len):
             - layout = 'TN'
             Shape (num_layers, 2, prev_len + seq_length, batch_size, C_in)
         """
-        contextual_embeddings, new_states = self._backbone_model(inputs, states, prev_len)
+        contextual_embeddings, new_states = self._backbone_model(inputs, states)
         logits = self._lm_head(contextual_embeddings)
         return logits, new_states
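The thread running through the `gpt2.py` changes above: `prev_len` is no longer passed down through every `hybrid_forward`; it is recovered from the cached states themselves with `F.npx.shape_array`. A standalone sketch of the same idea, using the imperative `mx.npx` namespace and hypothetical sizes:

```python
import mxnet as mx
mx.npx.set_np()

# GPT2Model states for layout 'NT': (num_layers, 2, batch_size, prev_len, C_in)
states = mx.np.zeros((12, 2, 4, 7, 768))
prev_len = mx.npx.shape_array(states)[3]  # recovered from the cache itself
print(prev_len)  # 7, no extra prev_len argument needed
```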
diff --git a/src/gluonnlp/models/model_zoo_checksums/gpt2.txt b/src/gluonnlp/models/model_zoo_checksums/gpt2.txt
index f117b813d1..0f4af681fa 100644
--- a/src/gluonnlp/models/model_zoo_checksums/gpt2.txt
+++ b/src/gluonnlp/models/model_zoo_checksums/gpt2.txt
@@ -1,15 +1,16 @@
-gpt2_124M/model-fac1f39c.yml fac1f39c804e324c69162b9b37bd24ab98241612 424
-gpt2_124M/model_lm-99b90604.params 99b9060488b4542ccd045c28401da10a3158ca80 497771820
-gpt2_124M/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
-gpt2_124M/gpt2-9dc62091.vocab 9dc620913410d5ec1a988abf852891e1c9f0f649 558055
-gpt2_124M/model-bfed311d.params bfed311d5c980ba475f90ccf7f536d25c3b40386 497769466
-gpt2_355M/model-2aea05ff.yml 2aea05ff1e67ef816b3f824102da8b7b1292a620 425
-gpt2_355M/model_lm-eed0e964.params eed0e964f4222823a557acfee2c106f228ce0188 1419317644
-gpt2_355M/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
-gpt2_355M/gpt2-9dc62091.vocab 9dc620913410d5ec1a988abf852891e1c9f0f649 558055
-gpt2_355M/model-81dee612.params 81dee612413733899f6e5fbbeac91da781805e1b 1419312986
-gpt2_774M/model-c9555788.yml c95557880783ec4f94b09b5b045c8d9e9a198e4d 425
-gpt2_774M/model_lm-cfbfa641.params cfbfa6419aaf1eae480fba5a1a7c8ea6096d43d6 3096157676
-gpt2_774M/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
-gpt2_774M/gpt2-9dc62091.vocab 9dc620913410d5ec1a988abf852891e1c9f0f649 558055
-gpt2_774M/model-9917e24e.params 9917e24e89c651793adea69042d6cceddfc7973c 3096150714
+gpt2_124M/model_lm-99b90604.params 99b9060488b4542ccd045c28401da10a3158ca80 497771820
+gpt2_124M/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
+gpt2_124M/gpt2-9dc62091.vocab 9dc620913410d5ec1a988abf852891e1c9f0f649 558055
+gpt2_124M/model-bfed311d.params bfed311d5c980ba475f90ccf7f536d25c3b40386 497769466
+gpt2_355M/model_lm-eed0e964.params eed0e964f4222823a557acfee2c106f228ce0188 1419317644
+gpt2_355M/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
+gpt2_355M/gpt2-9dc62091.vocab 9dc620913410d5ec1a988abf852891e1c9f0f649 558055
+gpt2_355M/model-81dee612.params 81dee612413733899f6e5fbbeac91da781805e1b 1419312986
+gpt2_774M/model_lm-cfbfa641.params cfbfa6419aaf1eae480fba5a1a7c8ea6096d43d6 3096157676
+gpt2_774M/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
+gpt2_774M/gpt2-9dc62091.vocab 9dc620913410d5ec1a988abf852891e1c9f0f649 558055
+gpt2_774M/model-9917e24e.params 9917e24e89c651793adea69042d6cceddfc7973c 3096150714
+gpt2_1558M/model_lm-c8489dcb.params c8489dcbdb0d39bc3eac6d1d62e0e3dace9faa8f 6230494540
+gpt2_1558M/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
+gpt2_1558M/gpt2-9dc62091.vocab 9dc620913410d5ec1a988abf852891e1c9f0f649 558055
+gpt2_1558M/model-af3dd713.params af3dd71313b55b4be5f52bdd538c9db054c1e190 6230485274
diff --git a/src/gluonnlp/sequence_sampler.py b/src/gluonnlp/sequence_sampler.py
index 36b98468cf..6ad1f1e035 100644
--- a/src/gluonnlp/sequence_sampler.py
+++ b/src/gluonnlp/sequence_sampler.py
@@ -577,7 +577,7 @@ def forward(self, inputs, states, src_seq_lengths=None):
         scores = mx.np.zeros(shape=(batch_size, beam_size), ctx=ctx)
         if beam_size > 1:
             scores[:, 1:beam_size] = LARGE_NEGATIVE_FLOAT
-        samples = step_input.reshape((batch_size, beam_size, 1))
+        samples = step_input.reshape((batch_size, beam_size, -1))
         batch_shift = mx.np.arange(0, batch_size * beam_size, beam_size, ctx=ctx, dtype=mx.np.int32)
         step = mx.np.array(0, ctx=ctx, dtype=mx.np.float32)
         for i in range(max_length):
diff --git a/tests/test_models_gpt2.py b/tests/test_models_gpt2.py
index 25f3ef6977..1b510ba332 100644
--- a/tests/test_models_gpt2.py
+++ b/tests/test_models_gpt2.py
@@ -41,16 +41,14 @@ def test_gpt2_small_config(compute_layout, ctx):
     gpt2_model.hybridize()
     hiddens, _ = gpt2_model(
         inputs,
-        gpt2_model.init_states(batch_size, ctx),
-        mx.np.array(0, dtype=np.int32, ctx=ctx)
+        gpt2_model.init_states(batch_size, ctx)
     )
     gpt2_model_tn = GPT2Model.from_cfg(cfg_tn)
     gpt2_model_tn.share_parameters(gpt2_model.collect_params())
     gpt2_model_tn.hybridize()
     hiddens_tn, _ = gpt2_model_tn(
         inputs.T,
-        gpt2_model_tn.init_states(batch_size, ctx),
-        mx.np.array(0, dtype=np.int32, ctx=ctx)
+        gpt2_model_tn.init_states(batch_size, ctx)
     )
     assert_allclose(np.swapaxes(hiddens_tn.asnumpy(), 0, 1),
                     hiddens.asnumpy(), 1E-4, 1E-4)
@@ -61,16 +59,14 @@ def test_gpt2_small_config(compute_layout, ctx):
     gpt2_lm_model.hybridize()
     logits, states = gpt2_lm_model(
         inputs,
-        gpt2_lm_model.init_states(batch_size, ctx),
-        mx.np.array(0, dtype=np.int32, ctx=ctx)
+        gpt2_lm_model.init_states(batch_size, ctx)
    )
     gpt2_lm_model_tn = GPT2ForLM(cfg_tn)
     gpt2_lm_model_tn.share_parameters(gpt2_lm_model.collect_params())
     gpt2_lm_model_tn.hybridize()
     logits_tn, states_tn = gpt2_lm_model_tn(
         inputs.T,
-        gpt2_lm_model_tn.init_states(batch_size, ctx),
-        mx.np.array(0, dtype=np.int32, ctx=ctx)
+        gpt2_lm_model_tn.init_states(batch_size, ctx)
     )
     assert_allclose(np.swapaxes(logits_tn.asnumpy(), 0, 1),
                     logits.asnumpy(), 1E-4, 1E-4)
@@ -91,8 +87,7 @@ def test_gpt2_incremental_states(ctx):
 
     one_time_hiddens, one_time_states = gpt2_model(
         inputs,
-        gpt2_model.init_states(batch_size, ctx),
-        mx.np.array(0, dtype=np.int32, ctx=ctx)
+        gpt2_model.init_states(batch_size, ctx)
     )
 
     states = gpt2_model.init_states(batch_size, ctx)
@@ -100,8 +95,7 @@ def test_gpt2_incremental_states(ctx):
     for i in range(sequence_length):
         hiddens, states = gpt2_model(
             inputs[:, i:i+1],
-            states,
-            mx.np.array(i, dtype=np.int32, ctx=ctx)
+            states
         )
         hiddens_l.append(hiddens)
     hiddens_concat = mx.np.concatenate(hiddens_l, axis=1)
@@ -113,7 +107,7 @@
 
 @pytest.mark.slow
 @pytest.mark.remote_required
-@pytest.mark.parametrize('model_name', list_pretrained_gpt2())
+@pytest.mark.parametrize('model_name', ['gpt2_124M', 'gpt2_355M', 'gpt2_774M'])
 def test_gpt2(model_name, ctx):
     # test from pretrained
     assert len(list_pretrained_gpt2()) > 0
@@ -143,8 +137,7 @@ def test_gpt2(model_name, ctx):
     )
     logits, _ = gpt2_lm_model(
         input_ids,
-        gpt2_lm_model.init_states(batch_size, ctx),
-        mx.np.array(0, dtype=np.int32, ctx=ctx)
+        gpt2_lm_model.init_states(batch_size, ctx)
     )
     mx.npx.waitall()
     # test backward
@@ -152,8 +145,7 @@
     with mx.autograd.record():
         logits, _ = gpt2_lm_model(
             input_ids,
-            gpt2_lm_model.init_states(batch_size, ctx),
-            mx.np.array(0, dtype=np.int32, ctx=ctx)
+            gpt2_lm_model.init_states(batch_size, ctx)
         )
         loss = label_smooth_loss(logits, input_ids)
         loss.backward()
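Finally, the one-character `sequence_sampler.py` change above is what makes conditional generation possible: reshaping the initial `step_input` with `-1` instead of a hard-coded `1` lets the sampler start from a multi-token prompt, which is what `interactive_conditional_gpt2_samples.py` relies on. A standalone sketch with hypothetical sizes:

```python
import mxnet as mx
mx.npx.set_np()

batch_size, beam_size, prompt_len = 2, 1, 5
step_input = mx.np.zeros((batch_size * beam_size, prompt_len))

# Old behaviour: reshape((batch_size, beam_size, 1)) only worked for a
# single start token. With -1 the prompt length is inferred instead.
samples = step_input.reshape((batch_size, beam_size, -1))
print(samples.shape)  # (2, 1, 5)
```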