From 1e955cb87c1f0c27241a923e4f110f78f062bd39 Mon Sep 17 00:00:00 2001 From: smallv0221 <33639025+smallv0221@users.noreply.github.com> Date: Sun, 28 Feb 2021 20:59:19 +0800 Subject: [PATCH] Fix dataset doc and fix roberta tokenizer and update SQuAD example (#42) * fix dataset doc and fix roberta tokenizer and update SQuAD example * Change to relative link * Update annotation. * Minor fix --- docs/datasets.md | 8 +- examples/experimental/run_squad_test.py | 401 ++++++++++---- .../SQuAD/README.md | 8 +- .../SQuAD/args.py | 16 +- .../SQuAD/run_squad.py | 421 +++++++++------ paddlenlp/datasets/experimental/dataset.py | 12 +- paddlenlp/datasets/experimental/squad.py | 2 + paddlenlp/metrics/squad.py | 500 +++++++----------- paddlenlp/transformers/roberta/tokenizer.py | 2 +- 9 files changed, 794 insertions(+), 576 deletions(-) diff --git a/docs/datasets.md b/docs/datasets.md index dacaa421c1405..dfad3ed6bf88c 100644 --- a/docs/datasets.md +++ b/docs/datasets.md @@ -34,7 +34,7 @@ PaddleNLP提供了 | ---- | --------- | ------ | | [Conll05](https://www.cs.upc.edu/~srlconll/spec.html) | 语义角色标注数据集| `paddle.text.datasets.Conll05st`| | [MSRA_NER](https://github.com/lemonhu/NER-BERT-pytorch/tree/master/data/msra) | MSRA 命名实体识别数据集| `paddlenlp.datasets.MSRA_NER`| -| [Express_Ner](https://aistudio.baidu.com/aistudio/projectdetail/131360?channelType=0&channel=-1) | 快递单命名实体识别数据集| [express_ner](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/named_entity_recognition/express_ner/data)| +| [Express_Ner](https://aistudio.baidu.com/aistudio/projectdetail/131360?channelType=0&channel=-1) | 快递单命名实体识别数据集| [express_ner](../examples/named_entity_recognition/express_ner/data)| ## 机器翻译 @@ -47,13 +47,13 @@ PaddleNLP提供了 | 数据集名称 | 简介 | 调用方法 | | ---- | --------- | ------ | -| [CSSE COVID-19](https://github.com/CSSEGISandData/COVID-19) |约翰·霍普金斯大学系统科学与工程中心新冠病例数据 | [time_series](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/time_series)| +| [CSSE 
COVID-19](https://github.com/CSSEGISandData/COVID-19) |约翰·霍普金斯大学系统科学与工程中心新冠病例数据 | [time_series](../examples/time_series)| | [UCIHousing](https://archive.ics.uci.edu/ml/datasets/Housing) | 波士顿房价预测数据集 | `paddle.text.datasets.UCIHousing`| ## 语料库 | 数据集名称 | 简介 | 调用方法 | | ---- | --------- | ------ | -| [yahoo](https://webscope.sandbox.yahoo.com/catalog.php?datatype=l&guccounter=1) | 雅虎英文语料库 | [VAE](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/text_generation/vae-seq2seq)| +| [yahoo](https://webscope.sandbox.yahoo.com/catalog.php?datatype=l&guccounter=1) | 雅虎英文语料库 | [VAE](../examples/text_generation/vae-seq2seq)| | [PTB](http://www.fit.vutbr.cz/~imikolov/rnnlm/) | Penn Treebank Dataset | `paddlenlp.datasets.PTB`| -| [1 Billon words](https://opensource.google/projects/lm-benchmark) | 1 Billion Word Language Model Benchmark R13 Output 基准语料库| [ELMo](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/language_model/elmo)| +| [1 Billion words](https://opensource.google/projects/lm-benchmark) | 1 Billion Word Language Model Benchmark R13 Output 基准语料库| [ELMo](../examples/language_model/elmo)| diff --git a/examples/experimental/run_squad_test.py b/examples/experimental/run_squad_test.py index 39a2061804a8a..deccbbfabea44 100644 --- a/examples/experimental/run_squad_test.py +++ b/examples/experimental/run_squad_test.py @@ -1,116 +1,337 @@ -from paddlenlp.datasets.experimental import SQuAD -from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer -from paddlenlp.data import Stack, Tuple, Pad, Dict +import collections +import os +import random +import time +import json + from functools import partial +import numpy as np +import paddle + from paddle.io import DataLoader -from paddlenlp.datasets.experimental import load_dataset +from args import parse_args -train_ds, dev_ds = load_dataset('squad', splits=('train_v2', 'dev_v2')) -tokenizer = 
BertTokenizer.from_pretrained('bert-base-uncased') +import paddlenlp as ppnlp -print(len(train_ds)) -print(len(dev_ds)) -print(train_ds[0]) +from paddlenlp.data import Pad, Stack, Tuple, Dict +from paddlenlp.transformers import BertForQuestionAnswering, BertTokenizer, ErnieForQuestionAnswering, ErnieTokenizer +from paddlenlp.transformers import LinearDecayWithWarmup +from paddlenlp.metrics.squad import squad_evaluate, compute_prediction +from paddlenlp.datasets import load_dataset -print('-----------------------------------------------------------') +MODEL_CLASSES = { + "bert": (BertForQuestionAnswering, BertTokenizer), + "ernie": (ErnieForQuestionAnswering, ErnieTokenizer) +} -def prepare_train_features(examples): - # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results - # in one example possible giving several features when a context is long, each of those features having a - # context that overlaps a bit the context of the previous feature. - contexts = [examples[i]['context'] for i in range(5000)] - questions = [examples[i]['question'] for i in range(5000)] +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + paddle.seed(args.seed) - tokenized_examples = tokenizer( - questions, contexts, stride=128, max_seq_len=384) - print(len(tokenized_examples)) - # Since one example might give us several features if it has a long context, we need a map from a feature to - # its corresponding example. This key gives us just that. +def evaluate(model, data_loader, args): + model.eval() - # The offset mappings will give us a map from token to character position in the original context. This will - # help us compute the start_positions and end_positions. + all_start_logits = [] + all_end_logits = [] + tic_eval = time.time() - # Let's label those examples! 
+ for batch in data_loader: + input_ids, segment_ids = batch + start_logits_tensor, end_logits_tensor = model(input_ids, segment_ids) - for i, tokenized_example in enumerate(tokenized_examples): - # We will label impossible answers with the index of the CLS token. - input_ids = tokenized_example["input_ids"] - cls_index = input_ids.index(tokenizer.cls_token_id) - offsets = tokenized_example['offset_mapping'] + for idx in range(start_logits_tensor.shape[0]): + if len(all_start_logits) % 1000 == 0 and len(all_start_logits): + print("Processing example: %d" % len(all_start_logits)) + print('time per 1000:', time.time() - tic_eval) + tic_eval = time.time() - # Grab the sequence corresponding to that example (to know what is the context and what is the question). - sequence_ids = tokenized_example['segment_ids'] + all_start_logits.append(start_logits_tensor.numpy()[idx]) + all_end_logits.append(end_logits_tensor.numpy()[idx]) - # One example can give several spans, this is the index of the example containing this span of text. - sample_index = tokenized_example['overflow_to_sample'] - answers = examples[sample_index]['answers'] - answer_starts = examples[sample_index]['answer_starts'] - if i == 0: - print(answer_starts) - print(len(answer_starts)) - # If no answers are given, set the cls_index as answer. - if len(answer_starts) == 0: - tokenized_examples[i]["start_positions"] = cls_index - tokenized_examples[i]["end_positions"] = cls_index - else: - # Start/end character index of the answer in the text. - start_char = answer_starts[0] - end_char = start_char + len(answers[0]) - - # Start token index of the current span in the text. - token_start_index = 0 - while sequence_ids[token_start_index] != 1: - token_start_index += 1 - - # End token index of the current span in the text. 
- token_end_index = len(input_ids) - 2 - while sequence_ids[token_end_index] != 1: - token_end_index -= 1 - - # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). - if not (offsets[token_start_index][0] <= start_char and - offsets[token_end_index][1] >= end_char): + all_predictions, all_nbest_json, scores_diff_json = compute_prediction( + data_loader.dataset.data, data_loader.dataset.new_data, + (all_start_logits, all_end_logits), args.version_2_with_negative, + args.n_best_size, args.max_answer_length, + args.null_score_diff_threshold) + + # Can also write all_nbest_json and scores_diff_json files if needed + with open('prediction.json', "w", encoding='utf-8') as writer: + writer.write( + json.dumps( + all_predictions, ensure_ascii=False, indent=4) + "\n") + + squad_evaluate( + examples=data_loader.dataset.data, + preds=all_predictions, + na_probs=scores_diff_json) + + model.train() + + +class CrossEntropyLossForSQuAD(paddle.nn.Layer): + def __init__(self): + super(CrossEntropyLossForSQuAD, self).__init__() + + def forward(self, y, label): + start_logits, end_logits = y + start_position, end_position = label + start_position = paddle.unsqueeze(start_position, axis=-1) + end_position = paddle.unsqueeze(end_position, axis=-1) + start_loss = paddle.nn.functional.softmax_with_cross_entropy( + logits=start_logits, label=start_position, soft_label=False) + start_loss = paddle.mean(start_loss) + end_loss = paddle.nn.functional.softmax_with_cross_entropy( + logits=end_logits, label=end_position, soft_label=False) + end_loss = paddle.mean(end_loss) + + loss = (start_loss + end_loss) / 2 + return loss + + +def run(args): + paddle.set_device("gpu" if args.n_gpu else "cpu") + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + tokenizer = 
tokenizer_class.from_pretrained(args.model_name_or_path) + + set_seed(args) + if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: + if os.path.exists(args.model_name_or_path): + print("init checkpoint from %s" % args.model_name_or_path) + + model = model_class.from_pretrained(args.model_name_or_path) + + if paddle.distributed.get_world_size() > 1: + model = paddle.DataParallel(model) + + def prepare_train_features(examples): + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. + + contexts = [examples[i]['context'] for i in range(len(examples))] + questions = [examples[i]['question'] for i in range(len(examples))] + + tokenized_examples = tokenizer( + questions, + contexts, + stride=args.doc_stride, + max_seq_len=args.max_seq_length) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + + # The offset mappings will give us a map from token to character position in the original context. This will + # help us compute the start_positions and end_positions. + + # Let's label those examples! + + for i, tokenized_example in enumerate(tokenized_examples): + # We will label impossible answers with the index of the CLS token. + input_ids = tokenized_example["input_ids"] + cls_index = input_ids.index(tokenizer.cls_token_id) + offsets = tokenized_example['offset_mapping'] + + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized_example['segment_ids'] + + # One example can give several spans, this is the index of the example containing this span of text. 
+ sample_index = tokenized_example['overflow_to_sample'] + answers = examples[sample_index]['answers'] + answer_starts = examples[sample_index]['answer_starts'] + # If no answers are given, set the cls_index as answer. + + if len(answer_starts) == 0: tokenized_examples[i]["start_positions"] = cls_index tokenized_examples[i]["end_positions"] = cls_index else: - # Otherwise move the token_start_index and token_end_index to the two ends of the answer. - # Note: we could go after the last offset if the answer is the last word (edge case). - while token_start_index < len(offsets) and offsets[ - token_start_index][0] <= start_char: + # Start/end character index of the answer in the text. + start_char = answer_starts[0] + end_char = start_char + len(answers[0]) + + # Start token index of the current span in the text. + token_start_index = 0 + while sequence_ids[token_start_index] != 1: token_start_index += 1 - tokenized_examples[i]["start_positions"] = token_start_index - 1 - while offsets[token_end_index][1] >= end_char: + + # End token index of the current span in the text. + token_end_index = len(input_ids) - 2 + while sequence_ids[token_end_index] != 1: token_end_index -= 1 - tokenized_examples[i]["end_positions"] = token_end_index + 1 - return tokenized_examples + # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). + if not (offsets[token_start_index][0] <= start_char and + offsets[token_end_index][1] >= end_char): + tokenized_examples[i]["start_positions"] = cls_index + tokenized_examples[i]["end_positions"] = cls_index + else: + # Otherwise move the token_start_index and token_end_index to the two ends of the answer. + # Note: we could go after the last offset if the answer is the last word (edge case). 
+ while token_start_index < len(offsets) and offsets[ + token_start_index][0] <= start_char: + token_start_index += 1 + tokenized_examples[i][ + "start_positions"] = token_start_index - 1 + while offsets[token_end_index][1] >= end_char: + token_end_index -= 1 + tokenized_examples[i]["end_positions"] = token_end_index + 1 + + return tokenized_examples + + if args.do_train: + if args.train_file: + train_ds = load_dataset('sqaud', data_files=args.train_file) + elif args.version_2_with_negative: + train_ds = load_dataset('squad', splits='train_v2') + else: + train_ds = load_dataset('squad', splits='train_v1') + train_ds.map(prepare_train_features, lazy=False) + train_batch_sampler = paddle.io.DistributedBatchSampler( + train_ds, batch_size=args.batch_size, shuffle=True) + train_batchify_fn = lambda samples, fn=Dict({ + "input_ids": Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]), # input + "segment_ids": Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]), # segment + "start_positions": Stack(dtype="int64"), # start_pos + "end_positions": Stack(dtype="int64") # end_pos + }): fn(samples) + + train_data_loader = DataLoader( + dataset=train_ds, + batch_sampler=train_batch_sampler, + collate_fn=train_batchify_fn, + return_list=True) + + num_training_steps = args.max_steps if args.max_steps > 0 else len( + train_data_loader) * args.num_train_epochs + + lr_scheduler = LinearDecayWithWarmup( + args.learning_rate, num_training_steps, args.warmup_proportion) + + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + epsilon=args.adam_epsilon, + parameters=model.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ]) + criterion = CrossEntropyLossForSQuAD() + + global_step = 0 + tic_train = time.time() + for epoch in range(args.num_train_epochs): + for step, batch in enumerate(train_data_loader): + global_step += 1 + 
input_ids, segment_ids, start_positions, end_positions = batch + + logits = model(input_ids=input_ids, token_type_ids=segment_ids) + loss = criterion(logits, (start_positions, end_positions)) + + if global_step % args.logging_steps == 0: + print( + "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s" + % (global_step, epoch, step, loss, + args.logging_steps / (time.time() - tic_train))) + tic_train = time.time() + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_gradients() + + if global_step % args.save_steps == 0: + if (not args.n_gpu > 1 + ) or paddle.distributed.get_rank() == 0: + output_dir = os.path.join(args.output_dir, + "model_%d" % global_step) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # need better way to get inner model of DataParallel + model_to_save = model._layers if isinstance( + model, paddle.DataParallel) else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + print('Saving checkpoint to:', output_dir) + + if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: + output_dir = os.path.join(args.output_dir, "model_%d" % global_step) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # need better way to get inner model of DataParallel + model_to_save = model._layers if isinstance( + model, paddle.DataParallel) else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + print('Saving checkpoint to:', output_dir) + + def prepare_validation_features(examples): + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. 
+ contexts = [examples[i]['context'] for i in range(len(examples))] + questions = [examples[i]['question'] for i in range(len(examples))] + + tokenized_examples = tokenizer( + questions, + contexts, + stride=args.doc_stride, + max_seq_len=args.max_seq_length) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + for i, tokenized_example in enumerate(tokenized_examples): + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized_example['segment_ids'] + sample_index = tokenized_example['overflow_to_sample'] + + # One example can give several spans, this is the index of the example containing this span of text. + tokenized_examples[i]["example_id"] = examples[sample_index]['id'] + + # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token + # position is part of the context or not. 
+ tokenized_examples[i]["offset_mapping"] = [ + (o if sequence_ids[k] == 1 else None) + for k, o in enumerate(tokenized_example["offset_mapping"]) + ] + + return tokenized_examples + + if args.do_pred: + if args.predict_file: + dev_ds = load_dataset('sqaud', data_files=args.predict_file) + elif args.version_2_with_negative: + dev_ds = load_dataset('squad', splits='dev_v2') + else: + dev_ds = load_dataset('squad', splits='dev_v1') + dev_ds.map(prepare_validation_features, lazy=False) + dev_batch_sampler = paddle.io.BatchSampler( + dev_ds, batch_size=args.batch_size, shuffle=False) -train_ds.map(prepare_train_features, lazy=False) + dev_batchify_fn = lambda samples, fn=Dict({ + "input_ids": Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]), # input + "segment_ids": Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]) # segment + }): fn(samples) -print(train_ds[0]) -print(train_ds[1]) -print(len(train_ds)) -print('-----------------------------------------------------') + dev_data_loader = DataLoader( + dataset=dev_ds, + batch_sampler=dev_batch_sampler, + collate_fn=dev_batchify_fn, + return_list=True) -train_batchify_fn = lambda samples, fn=Dict({ - "input_ids": Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]), # input - "segment_ids": Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]), # segment - "start_positions": Stack(dtype="int64"), # start_pos - "end_positions": Stack(dtype="int64") # end_pos -}): fn(samples) + if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: + evaluate(model, dev_data_loader, args) -train_data_loader = DataLoader( - dataset=train_ds, - batch_size=8, - collate_fn=train_batchify_fn, - return_list=True) -for batch in train_data_loader: - print(batch[0]) - print(batch[1]) - print(batch[2]) - print(batch[3]) - break +if __name__ == "__main__": + args = parse_args() + if args.n_gpu > 1: + paddle.distributed.spawn(run, args=(args, ), nprocs=args.n_gpu) + else: + run(args) diff --git 
a/examples/machine_reading_comprehension/SQuAD/README.md b/examples/machine_reading_comprehension/SQuAD/README.md index 856ef3dfeb5e6..7e582a6cd92d1 100644 --- a/examples/machine_reading_comprehension/SQuAD/README.md +++ b/examples/machine_reading_comprehension/SQuAD/README.md @@ -41,7 +41,7 @@ SQuAD v2.0 ### 数据准备 -为了方便开发者进行测试,我们内置了数据下载脚本,用户可以通过命令行传入`--version_2_with_negative`控制所需要的SQuAD数据集版本,也可以通过`--data_path`传入本地数据集的位置,数据集需保证与SQuAD数据集格式一致。 +为了方便开发者进行测试,我们内置了数据下载脚本,用户可以通过命令行传入`--version_2_with_negative`控制所需要的SQuAD数据集版本,也可以通过`--train_file`和`--predict_file`传入本地数据集的位置,数据集需保证与SQuAD数据集格式一致。 ### Fine-tune @@ -61,12 +61,16 @@ python -u ./run_squad.py \ --warmup_proportion 0.1 \ --weight_decay 0.01 \ --output_dir ./tmp/squad/ \ + --do_train \ + --do_pred \ --n_gpu 1 ``` * `model_type`: 预训练模型的种类。如bert,ernie,roberta等。 * `model_name_or_path`: 预训练模型的具体名称。如bert-base-uncased,bert-large-cased等。或者是模型文件的本地路径。 * `output_dir`: 保存模型checkpoint的路径。 +* `do_train`: 是否进行训练。 +* `do_pred`: 是否进行预测。 训练结束后模型会自动对结果进行评估,得到类似如下的输出: @@ -97,6 +101,8 @@ python -u ./run_squad.py \ --weight_decay 0.01 \ --output_dir ./tmp/squad/ \ --n_gpu 1 \ + --do_train \ + --do_pred \ --version_2_with_negative ``` diff --git a/examples/machine_reading_comprehension/SQuAD/args.py b/examples/machine_reading_comprehension/SQuAD/args.py index 872f91e1972d8..cc08b8438205f 100644 --- a/examples/machine_reading_comprehension/SQuAD/args.py +++ b/examples/machine_reading_comprehension/SQuAD/args.py @@ -4,10 +4,17 @@ def parse_args(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - "--data_path", + "--train_file", type=str, + required=False, default=None, - help="Directory of all the data for train, valid, test.") + help="Train data path.") + parser.add_argument( + "--predict_file", + type=str, + required=False, + default=None, + help="Predict data path.") parser.add_argument( "--model_type", default=None, @@ -123,6 +130,9 @@ def parse_args(): action='store_true', help="If true, the SQuAD 
examples contain some that do not have an answer. If using squad v2.0, it should be set true." ) - + parser.add_argument( + "--do_train", action='store_true', help="Whether to train the model.") + parser.add_argument( + "--do_pred", action='store_true', help="Whether to predict.") args = parser.parse_args() return args diff --git a/examples/machine_reading_comprehension/SQuAD/run_squad.py b/examples/machine_reading_comprehension/SQuAD/run_squad.py index 2e02b2e02c171..9b95913f70491 100644 --- a/examples/machine_reading_comprehension/SQuAD/run_squad.py +++ b/examples/machine_reading_comprehension/SQuAD/run_squad.py @@ -1,21 +1,8 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- import collections import os import random import time +import json from functools import partial import numpy as np @@ -26,10 +13,11 @@ import paddlenlp as ppnlp -from paddlenlp.data import Pad, Stack, Tuple +from paddlenlp.data import Pad, Stack, Tuple, Dict from paddlenlp.transformers import BertForQuestionAnswering, BertTokenizer, ErnieForQuestionAnswering, ErnieTokenizer from paddlenlp.transformers import LinearDecayWithWarmup -from paddlenlp.metrics.squad import squad_evaluate, compute_predictions +from paddlenlp.metrics.squad import squad_evaluate, compute_prediction +from paddlenlp.datasets import load_dataset MODEL_CLASSES = { "bert": (BertForQuestionAnswering, BertTokenizer), @@ -43,6 +31,46 @@ def set_seed(args): paddle.seed(args.seed) +def evaluate(model, data_loader, args): + model.eval() + + all_start_logits = [] + all_end_logits = [] + tic_eval = time.time() + + for batch in data_loader: + input_ids, segment_ids = batch + start_logits_tensor, end_logits_tensor = model(input_ids, segment_ids) + + for idx in range(start_logits_tensor.shape[0]): + if len(all_start_logits) % 1000 == 0 and len(all_start_logits): + print("Processing example: %d" % len(all_start_logits)) + print('time per 1000:', time.time() - tic_eval) + tic_eval = time.time() + + all_start_logits.append(start_logits_tensor.numpy()[idx]) + all_end_logits.append(end_logits_tensor.numpy()[idx]) + + all_predictions, all_nbest_json, scores_diff_json = compute_prediction( + data_loader.dataset.data, data_loader.dataset.new_data, + (all_start_logits, all_end_logits), args.version_2_with_negative, + args.n_best_size, args.max_answer_length, + args.null_score_diff_threshold) + + # Can also write all_nbest_json and scores_diff_json files if needed + with open('prediction.json', "w", encoding='utf-8') as writer: + writer.write( + json.dumps( + all_predictions, ensure_ascii=False, indent=4) + "\n") + + squad_evaluate( + examples=data_loader.dataset.data, + preds=all_predictions, + 
na_probs=scores_diff_json) + + model.train() + + class CrossEntropyLossForSQuAD(paddle.nn.Layer): def __init__(self): super(CrossEntropyLossForSQuAD, self).__init__() @@ -63,46 +91,7 @@ def forward(self, y, label): return loss -def evaluate(model, data_loader, args): - model.eval() - - RawResult = collections.namedtuple( - "RawResult", ["unique_id", "start_logits", "end_logits"]) - - all_results = [] - tic_eval = time.time() - - for batch in data_loader: - input_ids, segment_ids, unipue_ids = batch - start_logits_tensor, end_logits_tensor = model(input_ids, segment_ids) - - for idx in range(unipue_ids.shape[0]): - if len(all_results) % 1000 == 0 and len(all_results): - print("Processing example: %d" % len(all_results)) - print('time per 1000:', time.time() - tic_eval) - tic_eval = time.time() - unique_id = int(unipue_ids[idx]) - start_logits = [float(x) for x in start_logits_tensor.numpy()[idx]] - end_logits = [float(x) for x in end_logits_tensor.numpy()[idx]] - all_results.append( - RawResult( - unique_id=unique_id, - start_logits=start_logits, - end_logits=end_logits)) - - all_predictions, all_nbest_json, scores_diff_json = compute_predictions( - data_loader.dataset.examples, data_loader.dataset.features, all_results, - args.n_best_size, args.max_answer_length, args.do_lower_case, - args.version_2_with_negative, args.null_score_diff_threshold, - args.verbose, data_loader.dataset.tokenizer) - - squad_evaluate(data_loader.dataset.examples, all_predictions, - scores_diff_json, 1.0) - - model.train() - - -def do_train(args): +def run(args): paddle.set_device("gpu" if args.n_gpu else "cpu") if paddle.distributed.get_world_size() > 1: paddle.distributed.init_parallel_env() @@ -110,115 +99,229 @@ def do_train(args): args.model_type = args.model_type.lower() model_class, tokenizer_class = MODEL_CLASSES[args.model_type] tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) - root = args.data_path set_seed(args) - - train_dataset = ppnlp.datasets.SQuAD( - 
tokenizer=tokenizer, - doc_stride=args.doc_stride, - root=root, - version_2_with_negative=args.version_2_with_negative, - max_query_length=args.max_query_length, - max_seq_length=args.max_seq_length, - mode="train") - - train_batch_sampler = paddle.io.DistributedBatchSampler( - train_dataset, batch_size=args.batch_size, shuffle=True) - - train_batchify_fn = lambda samples, fn=Tuple( - Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]), # input - Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]), # segment - Stack(), # unipue_id - Stack(dtype="int64"), # start_pos - Stack(dtype="int64") # end_pos - ): [data for i, data in enumerate(fn(samples)) if i != 2] - - train_data_loader = DataLoader( - dataset=train_dataset, - batch_sampler=train_batch_sampler, - collate_fn=train_batchify_fn, - return_list=True) - - dev_dataset = ppnlp.datasets.SQuAD( - tokenizer=tokenizer, - doc_stride=args.doc_stride, - root=root, - version_2_with_negative=args.version_2_with_negative, - max_query_length=args.max_query_length, - max_seq_length=args.max_seq_length, - mode="dev") - - dev_batch_sampler = paddle.io.BatchSampler( - dev_dataset, batch_size=args.batch_size, shuffle=False) - - dev_batchify_fn = lambda samples, fn=Tuple( - Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]), # input - Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]), # segment - Stack() # unipue_id - ): fn(samples) - - dev_data_loader = DataLoader( - dataset=dev_dataset, - batch_sampler=dev_batch_sampler, - collate_fn=dev_batchify_fn, - return_list=True) + if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: + if os.path.exists(args.model_name_or_path): + print("init checkpoint from %s" % args.model_name_or_path) model = model_class.from_pretrained(args.model_name_or_path) if paddle.distributed.get_world_size() > 1: model = paddle.DataParallel(model) - num_training_steps = args.max_steps if args.max_steps > 0 else len( - train_dataset.examples) // args.batch_size * 
args.num_train_epochs - - lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, - args.warmup_proportion) - - optimizer = paddle.optimizer.AdamW( - learning_rate=lr_scheduler, - epsilon=args.adam_epsilon, - parameters=model.parameters(), - weight_decay=args.weight_decay, - apply_decay_param_fun=lambda x: x in [ - p.name for n, p in model.named_parameters() - if not any(nd in n for nd in ["bias", "norm"]) - ]) - criterion = CrossEntropyLossForSQuAD() - - global_step = 0 - tic_train = time.time() - for epoch in range(args.num_train_epochs): - for step, batch in enumerate(train_data_loader): - global_step += 1 - input_ids, segment_ids, start_positions, end_positions = batch - - logits = model(input_ids=input_ids, token_type_ids=segment_ids) - loss = criterion(logits, (start_positions, end_positions)) - - if global_step % args.logging_steps == 0: - print( - "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s" - % (global_step, epoch, step, loss, - args.logging_steps / (time.time() - tic_train))) - tic_train = time.time() - loss.backward() - optimizer.step() - lr_scheduler.step() - optimizer.clear_gradients() - - if global_step % args.save_steps == 0: - if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: - output_dir = os.path.join(args.output_dir, - "model_%d" % global_step) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - # need better way to get inner model of DataParallel - model_to_save = model._layers if isinstance( - model, paddle.DataParallel) else model - model_to_save.save_pretrained(output_dir) - tokenizer.save_pretrained(output_dir) - print('Saving checkpoint to:', output_dir) + def prepare_train_features(examples): + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. 
This results + # in one example possibly giving several features when a context is long, each of those features having a + # context that overlaps a bit with the context of the previous feature. + # NOTE: Almost the same functionality as HuggingFace's prepare_train_features function. The main difference is + # that HuggingFace uses ArrowTable as its basic data structure, while we use a list of dictionaries instead. + contexts = [examples[i]['context'] for i in range(len(examples))] + questions = [examples[i]['question'] for i in range(len(examples))] + + tokenized_examples = tokenizer( + questions, + contexts, + stride=args.doc_stride, + max_seq_len=args.max_seq_length) + + # Let's label those examples! + for i, tokenized_example in enumerate(tokenized_examples): + # We will label impossible answers with the index of the CLS token. + input_ids = tokenized_example["input_ids"] + cls_index = input_ids.index(tokenizer.cls_token_id) + + # The offset mappings will give us a map from token to character position in the original context. This will + # help us compute the start_positions and end_positions. + offsets = tokenized_example['offset_mapping'] + + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized_example['segment_ids'] + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = tokenized_example['overflow_to_sample'] + answers = examples[sample_index]['answers'] + answer_starts = examples[sample_index]['answer_starts'] + + # If no answers are given, set the cls_index as answer. + if len(answer_starts) == 0: + tokenized_examples[i]["start_positions"] = cls_index + tokenized_examples[i]["end_positions"] = cls_index + else: + # Start/end character index of the answer in the text. + start_char = answer_starts[0] + end_char = start_char + len(answers[0]) + + # Start token index of the current span in the text. 
+ token_start_index = 0 + while sequence_ids[token_start_index] != 1: + token_start_index += 1 + + # End token index of the current span in the text. + token_end_index = len(input_ids) - 2 + while sequence_ids[token_end_index] != 1: + token_end_index -= 1 + + # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). + if not (offsets[token_start_index][0] <= start_char and + offsets[token_end_index][1] >= end_char): + tokenized_examples[i]["start_positions"] = cls_index + tokenized_examples[i]["end_positions"] = cls_index + else: + # Otherwise move the token_start_index and token_end_index to the two ends of the answer. + # Note: we could go after the last offset if the answer is the last word (edge case). + while token_start_index < len(offsets) and offsets[ + token_start_index][0] <= start_char: + token_start_index += 1 + tokenized_examples[i][ + "start_positions"] = token_start_index - 1 + while offsets[token_end_index][1] >= end_char: + token_end_index -= 1 + tokenized_examples[i]["end_positions"] = token_end_index + 1 + + return tokenized_examples + + if args.do_train: + if args.train_file: + train_ds = load_dataset('sqaud', data_files=args.train_file) + elif args.version_2_with_negative: + train_ds = load_dataset('squad', splits='train_v2') + else: + train_ds = load_dataset('squad', splits='train_v1') + train_ds.map(prepare_train_features, lazy=False) + train_batch_sampler = paddle.io.DistributedBatchSampler( + train_ds, batch_size=args.batch_size, shuffle=True) + train_batchify_fn = lambda samples, fn=Dict({ + "input_ids": Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]), # input + "segment_ids": Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]), # segment + "start_positions": Stack(dtype="int64"), # start_pos + "end_positions": Stack(dtype="int64") # end_pos + }): fn(samples) + + train_data_loader = DataLoader( + dataset=train_ds, + batch_sampler=train_batch_sampler, + 
collate_fn=train_batchify_fn, + return_list=True) + + num_training_steps = args.max_steps if args.max_steps > 0 else len( + train_data_loader) * args.num_train_epochs + + lr_scheduler = LinearDecayWithWarmup( + args.learning_rate, num_training_steps, args.warmup_proportion) + + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + epsilon=args.adam_epsilon, + parameters=model.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ]) + criterion = CrossEntropyLossForSQuAD() + + global_step = 0 + tic_train = time.time() + for epoch in range(args.num_train_epochs): + for step, batch in enumerate(train_data_loader): + global_step += 1 + input_ids, segment_ids, start_positions, end_positions = batch + + logits = model(input_ids=input_ids, token_type_ids=segment_ids) + loss = criterion(logits, (start_positions, end_positions)) + + if global_step % args.logging_steps == 0: + print( + "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s" + % (global_step, epoch, step, loss, + args.logging_steps / (time.time() - tic_train))) + tic_train = time.time() + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_gradients() + + if global_step % args.save_steps == 0: + if (not args.n_gpu > 1 + ) or paddle.distributed.get_rank() == 0: + output_dir = os.path.join(args.output_dir, + "model_%d" % global_step) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # need better way to get inner model of DataParallel + model_to_save = model._layers if isinstance( + model, paddle.DataParallel) else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + print('Saving checkpoint to:', output_dir) + + if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: + output_dir = os.path.join(args.output_dir, "model_%d" % global_step) + if not 
os.path.exists(output_dir): + os.makedirs(output_dir) + # need better way to get inner model of DataParallel + model_to_save = model._layers if isinstance( + model, paddle.DataParallel) else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + print('Saving checkpoint to:', output_dir) + + def prepare_validation_features(examples): + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possibly giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. + #NOTE: Almost the same functionality as HuggingFace's prepare_validation_features function. The main difference is + # that HuggingFace uses ArrowTable as basic data structure, while we use list of dictionaries instead. + contexts = [examples[i]['context'] for i in range(len(examples))] + questions = [examples[i]['question'] for i in range(len(examples))] + + tokenized_examples = tokenizer( + questions, + contexts, + stride=args.doc_stride, + max_seq_len=args.max_seq_length) + + # For validation, there is no need to compute start and end positions + for i, tokenized_example in enumerate(tokenized_examples): + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized_example['segment_ids'] + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = tokenized_example['overflow_to_sample'] + tokenized_examples[i]["example_id"] = examples[sample_index]['id'] + + # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token + # position is part of the context or not.
+ tokenized_examples[i]["offset_mapping"] = [ + (o if sequence_ids[k] == 1 else None) + for k, o in enumerate(tokenized_example["offset_mapping"]) + ] + + return tokenized_examples + + if args.do_pred: + if args.predict_file: + dev_ds = load_dataset('sqaud', data_files=args.predict_file) + elif args.version_2_with_negative: + dev_ds = load_dataset('squad', splits='dev_v2') + else: + dev_ds = load_dataset('squad', splits='dev_v1') + + dev_ds.map(prepare_validation_features, lazy=False) + dev_batch_sampler = paddle.io.BatchSampler( + dev_ds, batch_size=args.batch_size, shuffle=False) + + dev_batchify_fn = lambda samples, fn=Dict({ + "input_ids": Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]), # input + "segment_ids": Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]) # segment + }): fn(samples) + + dev_data_loader = DataLoader( + dataset=dev_ds, + batch_sampler=dev_batch_sampler, + collate_fn=dev_batchify_fn, + return_list=True) if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: evaluate(model, dev_data_loader, args) @@ -227,6 +330,6 @@ def do_train(args): if __name__ == "__main__": args = parse_args() if args.n_gpu > 1: - paddle.distributed.spawn(do_train, args=(args, ), nprocs=args.n_gpu) + paddle.distributed.spawn(run, args=(args, ), nprocs=args.n_gpu) else: - do_train(args) + run(args) diff --git a/paddlenlp/datasets/experimental/dataset.py b/paddlenlp/datasets/experimental/dataset.py index f271920a35c14..2105c7654c64a 100644 --- a/paddlenlp/datasets/experimental/dataset.py +++ b/paddlenlp/datasets/experimental/dataset.py @@ -67,7 +67,10 @@ def load_dataset(name, data_files=None, splits=None, lazy=None): ) or ( isinstance(data_files, tuple) and isinstance(data_files[0], str) ), "`data_files` should be a string or list of string or a tuple of string." 
- datasets += reader_instance.read_datasets(data_files) + if isinstance(data_files, str): + datasets += reader_instance.read_datasets([data_files]) + else: + datasets += reader_instance.read_datasets(data_files) if splits: assert isinstance(splits, str) or ( @@ -75,9 +78,12 @@ def load_dataset(name, data_files=None, splits=None, lazy=None): ) or ( isinstance(splits, tuple) and isinstance(splits[0], str) ), "`splits` should be a string or list of string or a tuple of string." - datasets += reader_instance.read_datasets(splits) + if isinstance(splits, str): + datasets += reader_instance.read_datasets([splits]) + else: + datasets += reader_instance.read_datasets(splits) - return datasets + return datasets if len(datasets) > 1 else datasets[0] @classmethod diff --git a/paddlenlp/datasets/experimental/squad.py b/paddlenlp/datasets/experimental/squad.py index 892b591f7b7fc..05dda9c76840b 100644 --- a/paddlenlp/datasets/experimental/squad.py +++ b/paddlenlp/datasets/experimental/squad.py @@ -54,6 +54,8 @@ def _read(self, filename): for qa in paragraph["qas"]: qas_id = qa["id"] question = qa["question"].strip() + answer_starts = [] + answers = [] is_impossible = False if "is_impossible" in qa.keys(): diff --git a/paddlenlp/metrics/squad.py b/paddlenlp/metrics/squad.py index da5d11f00ed85..9aada6e928b0e 100644 --- a/paddlenlp/metrics/squad.py +++ b/paddlenlp/metrics/squad.py @@ -1,10 +1,18 @@ -"""Official evaluation script for SQuAD version 2.0. +# coding=utf-8 +# Copyright 2020 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. -In addition to basic functionality, we also compute additional statistics and -plot precision-recall curves if an additional na_prob.json file is provided. -This file is expected to map question ID's to the model's predicted probability -that a question is unanswerable. -""" import collections import re import string @@ -14,335 +22,198 @@ import math -def compute_predictions(all_examples, - all_features, - all_results, - n_best_size, - max_answer_length, - do_lower_case, - version_2_with_negative, - null_score_diff_threshold, - verbose, - tokenizer, - is_whitespace_splited=True): - """Write final predictions to the json file and log-odds of null if needed.""" - - example_index_to_features = collections.defaultdict(list) - for feature in all_features: - example_index_to_features[feature.example_index].append(feature) - - unique_id_to_result = {} - for result in all_results: - unique_id_to_result[result.unique_id] = result - - _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name - "PrelimPrediction", [ - "feature_index", "start_index", "end_index", "start_logit", - "end_logit" - ]) - +def compute_prediction(examples, + features, + predictions, + version_2_with_negative: bool=False, + n_best_size: int=20, + max_answer_length: int=30, + null_score_diff_threshold: float=0.0, + is_whitespace_splited=True): + """ + Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the + original contexts. This is the base postprocessing functions for models that only return start and end logits. + + Args: + examples: The non-preprocessed dataset (see the main script for more information). + features: The processed dataset (see the main script for more information). 
+ predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): + The predictions of the model: two arrays containing the start logits and the end logits respectively. Its + first dimension must match the number of elements of :obj:`features`. + version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the underlying dataset contains examples with no answers. + n_best_size (:obj:`int`, `optional`, defaults to 20): + The total number of n-best predictions to generate when looking for an answer. + max_answer_length (:obj:`int`, `optional`, defaults to 30): + The maximum length of an answer that can be generated. This is needed because the start and end predictions + are not conditioned on one another. + null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0): + The threshold used to select the null answer: if the best answer has a score that is less than the score of + the null answer minus this threshold, the null answer is selected for this example (note that the score of + the null answer for an example giving several features is the minimum of the scores for the null answer on + each feature: all features must be aligned on the fact they `want` to predict a null answer). + + Only useful when :obj:`version_2_with_negative` is :obj:`True`. + """ + assert len( + predictions + ) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)." + all_start_logits, all_end_logits = predictions + + assert len(predictions[0]) == len( + features), "Number of predictions should be equal to number of features." + + # Build a map example to its corresponding features. + example_id_to_index = {k['id']: i for i, k in enumerate(examples)} + features_per_example = collections.defaultdict(list) + for i, feature in enumerate(features): + features_per_example[example_id_to_index[feature["example_id"]]].append( + i) + + # The dictionaries we have to fill. 
all_predictions = collections.OrderedDict() all_nbest_json = collections.OrderedDict() + scores_diff_json = collections.OrderedDict() - print(len(unique_id_to_result)) - for (example_index, example) in enumerate(all_examples): - features = example_index_to_features[example_index] + # Let's loop over all the examples! + for example_index, example in enumerate(examples): + # Those are the indices of the features associated to the current example. + feature_indices = features_per_example[example_index] + min_null_prediction = None prelim_predictions = [] - # keep track of the minimum score of null start+end of position 0 - score_null = 1000000 # large and positive - min_null_feature_index = 0 # the paragraph slice with min mull score - null_start_logit = 0 # the start logit at the slice with min null score - null_end_logit = 0 # the end logit at the slice with min null score - for (feature_index, feature) in enumerate(features): - result = unique_id_to_result[feature.unique_id] - start_indexes = _get_best_indexes(result.start_logits, n_best_size) - end_indexes = _get_best_indexes(result.end_logits, n_best_size) - # if we could have irrelevant answers, get the min score of irrelevant - if version_2_with_negative: - feature_null_score = result.start_logits[0] + result.end_logits[ - 0] - if feature_null_score < score_null: - score_null = feature_null_score - min_null_feature_index = feature_index - null_start_logit = result.start_logits[0] - null_end_logit = result.end_logits[0] + + # Looping through all the features associated to the current example. + for feature_index in feature_indices: + # We grab the predictions of the model for this feature. + start_logits = all_start_logits[feature_index] + end_logits = all_end_logits[feature_index] + # This is what will allow us to map some of the positions in our logits to spans of text in the original + # context.
+ offset_mapping = features[feature_index]["offset_mapping"] + # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context + # available in the current feature. + token_is_max_context = features[feature_index].get( + "token_is_max_context", None) + + # Update minimum null prediction. + feature_null_score = start_logits[0] + end_logits[0] + if min_null_prediction is None or min_null_prediction[ + "score"] > feature_null_score: + min_null_prediction = { + "offsets": (0, 0), + "score": feature_null_score, + "start_logit": start_logits[0], + "end_logit": end_logits[0], + } + + # Go through all possibilities for the `n_best_size` greater start and end logits. + start_indexes = np.argsort(start_logits)[-1:-n_best_size - 1: + -1].tolist() + end_indexes = np.argsort(end_logits)[-1:-n_best_size - 1:-1].tolist( + ) for start_index in start_indexes: for end_index in end_indexes: - # We could hypothetically create invalid predictions, e.g., predict - # that the start of the span is in the question. We throw out all - # invalid predictions. - if start_index >= len(feature.tokens): - continue - if end_index >= len(feature.tokens): + # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond + # to part of the input_ids that are not in the context. + if (start_index >= len(offset_mapping) or + end_index >= len(offset_mapping) or + offset_mapping[start_index] is None or + offset_mapping[end_index] is None): continue - if start_index not in feature.token_to_orig_map: + # Don't consider answers with a length that is either < 0 or > max_answer_length. + if end_index < start_index or end_index - start_index + 1 > max_answer_length: continue - if end_index not in feature.token_to_orig_map: + # Don't consider answer that don't have the maximum context available (if such information is + # provided). 
+ if token_is_max_context is not None and not token_is_max_context.get( + str(start_index), False): continue - if not feature.token_is_max_context.get(start_index, False): - continue - if end_index < start_index: - continue - length = end_index - start_index + 1 - if length > max_answer_length: - continue - prelim_predictions.append( - _PrelimPrediction( - feature_index=feature_index, - start_index=start_index, - end_index=end_index, - start_logit=result.start_logits[start_index], - end_logit=result.end_logits[end_index])) - + prelim_predictions.append({ + "offsets": (offset_mapping[start_index][0], + offset_mapping[end_index][1]), + "score": + start_logits[start_index] + end_logits[end_index], + "start_logit": start_logits[start_index], + "end_logit": end_logits[end_index], + }) if version_2_with_negative: - prelim_predictions.append( - _PrelimPrediction( - feature_index=min_null_feature_index, - start_index=0, - end_index=0, - start_logit=null_start_logit, - end_logit=null_end_logit)) - prelim_predictions = sorted( - prelim_predictions, - key=lambda x: (x.start_logit + x.end_logit), - reverse=True) - - _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name - "NbestPrediction", ["text", "start_logit", "end_logit"]) - - seen_predictions = {} - nbest = [] - for pred in prelim_predictions: - if len(nbest) >= n_best_size: - break - feature = features[pred.feature_index] - if pred.start_index > 0: # this is a non-null prediction - tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1 - )] - orig_doc_start = feature.token_to_orig_map[pred.start_index] - orig_doc_end = feature.token_to_orig_map[pred.end_index] - orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + - 1)] - tok_text = " ".join(tok_tokens) - - # De-tokenize WordPieces that have been split off. 
- tok_text = tok_text.replace(" ##", "") - tok_text = tok_text.replace("##", "") - - # Clean whitespace - tok_text = tok_text.strip() - tok_text = " ".join(tok_text.split()) - orig_text = " ".join(orig_tokens) - final_text = get_final_text(tok_text, orig_text, tokenizer, - verbose) - if not is_whitespace_splited: - final_text = final_text.replace(' ', '') - if final_text in seen_predictions: - continue - - seen_predictions[final_text] = True - else: - final_text = "" - seen_predictions[final_text] = True - - nbest.append( - _NbestPrediction( - text=final_text, - start_logit=pred.start_logit, - end_logit=pred.end_logit)) - - # if we didn't inlude the empty option in the n-best, inlcude it - if version_2_with_negative: - if "" not in seen_predictions: - nbest.append( - _NbestPrediction( - text="", - start_logit=null_start_logit, - end_logit=null_end_logit)) - # In very rare edge cases we could have no valid predictions. So we - # just create a nonce prediction in this case to avoid failure. - if not nbest: - nbest.append( - _NbestPrediction( - text="empty", start_logit=0.0, end_logit=0.0)) - - assert len(nbest) >= 1 - - total_scores = [] - best_non_null_entry = None - for entry in nbest: - total_scores.append(entry.start_logit + entry.end_logit) - if not best_non_null_entry: - if entry.text: - best_non_null_entry = entry - else: - best_non_null_entry = _NbestPrediction( - text="empty", start_logit=0.0, end_logit=0.0) - - probs = _compute_softmax(total_scores) - - nbest_json = [] - for (i, entry) in enumerate(nbest): - output = collections.OrderedDict() - output["text"] = entry.text - output["probability"] = probs[i] - output["start_logit"] = entry.start_logit - output["end_logit"] = entry.end_logit - nbest_json.append(output) - - assert len(nbest_json) >= 1 + # Add the minimum null prediction + prelim_predictions.append(min_null_prediction) + null_score = min_null_prediction["score"] + + # Only keep the best `n_best_size` predictions. 
+ predictions = sorted( + prelim_predictions, key=lambda x: x["score"], + reverse=True)[:n_best_size] + + # Add back the minimum null prediction if it was removed because of its low score. + if version_2_with_negative and not any(p["offsets"] == (0, 0) + for p in predictions): + predictions.append(min_null_prediction) + + # Use the offsets to gather the answer text in the original context. + context = example["context"] + for pred in predictions: + offsets = pred.pop("offsets") + pred["text"] = context[offsets[0]:offsets[1]] + + # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid + # failure. + if len(predictions) == 0 or (len(predictions) == 1 and + predictions[0]["text"] == ""): + predictions.insert(0, { + "text": "empty", + "start_logit": 0.0, + "end_logit": 0.0, + "score": 0.0 + }) + + # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using + # the LogSumExp trick). + scores = np.array([pred.pop("score") for pred in predictions]) + exp_scores = np.exp(scores - np.max(scores)) + probs = exp_scores / exp_scores.sum() + + # Include the probabilities in our predictions. + for prob, pred in zip(probs, predictions): + pred["probability"] = prob + + # Pick the best prediction. If the null answer is not possible, this is easy. if not version_2_with_negative: - all_predictions[example.qas_id] = nbest_json[0]["text"] + all_predictions[example["id"]] = predictions[0]["text"] else: - - # predict "" iff the null score - the score of best non-null > threshold - score_diff = score_null - best_non_null_entry.start_logit - ( - best_non_null_entry.end_logit) - scores_diff_json[example.qas_id] = score_diff + # Otherwise we first need to find the best non-empty prediction. + i = 0 + while predictions[i]["text"] == "": + i += 1 + best_non_null_pred = predictions[i] + + # Then we compare to the null prediction using the threshold. 
+ score_diff = null_score - best_non_null_pred[ + "start_logit"] - best_non_null_pred["end_logit"] + scores_diff_json[example["id"]] = float( + score_diff) # To be JSON-serializable. if score_diff > null_score_diff_threshold: - all_predictions[example.qas_id] = "" + all_predictions[example["id"]] = "" else: - all_predictions[example.qas_id] = best_non_null_entry.text + all_predictions[example["id"]] = best_non_null_pred["text"] - all_nbest_json[example.qas_id] = nbest_json + # Make `predictions` JSON-serializable by casting np.float back to float. + all_nbest_json[example["id"]] = [{ + k: (float(v) + if isinstance(v, (np.float16, np.float32, np.float64)) else v) + for k, v in pred.items() + } for pred in predictions] return all_predictions, all_nbest_json, scores_diff_json -def get_final_text(pred_text, orig_text, tokenizer, verbose): - """Project the tokenized prediction back to the original text.""" - - # When we created the data, we kept track of the alignment between original - # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So - # now `orig_text` contains the span of our original text corresponding to the - # span that we predicted. - # - # However, `orig_text` may contain extra characters that we don't want in - # our prediction. - # - # For example, let's say: - # pred_text = steve smith - # orig_text = Steve Smith's - # - # We don't want to return `orig_text` because it contains the extra "'s". - # - # We don't want to return `pred_text` because it's already been normalized - # (the SQuAD eval script also does punctuation stripping/lower casing but - # our tokenizer does additional normalization like stripping accent - # characters). - # - # What we really want to return is "Steve Smith". - # - # Therefore, we have to apply a semi-complicated alignment heruistic between - # `pred_text` and `orig_text` to get a character-to-charcter alignment. This - # can fail in certain cases in which case we just return `orig_text`. 
- - def _strip_spaces(text): - ns_chars = [] - ns_to_s_map = collections.OrderedDict() - for (i, c) in enumerate(text): - if c == " ": - continue - ns_to_s_map[len(ns_chars)] = i - ns_chars.append(c) - ns_text = "".join(ns_chars) - return (ns_text, ns_to_s_map) - - # We first tokenize `orig_text`, strip whitespace from the result - # and `pred_text`, and check if they are the same length. If they are - # NOT the same length, the heuristic has failed. If they are the same - # length, we assume the characters are one-to-one aligned. - tok_text = " ".join(tokenizer.basic_tokenizer.tokenize(orig_text)) - start_position = tok_text.find(pred_text) - if start_position == -1: - if verbose: - print(u"Unable to find text: '%s' in '%s'" % (pred_text, tok_text)) - return orig_text - end_position = start_position + len(pred_text) - 1 - - (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) - (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) - - if len(orig_ns_text) != len(tok_ns_text): - if verbose: - print(u"Length not equal after stripping spaces: '%s' vs '%s'" % - (orig_ns_text, tok_ns_text)) - return orig_text - - # We then project the characters in `pred_text` back to `orig_text` using - # the character-to-character alignment. 
- tok_s_to_ns_map = {} - for i, tok_index in tok_ns_to_s_map.items(): - tok_s_to_ns_map[tok_index] = i - - orig_start_position = None - if start_position in tok_s_to_ns_map: - ns_start_position = tok_s_to_ns_map[start_position] - if ns_start_position in orig_ns_to_s_map: - orig_start_position = orig_ns_to_s_map[ns_start_position] - - if orig_start_position is None: - if verbose: - print(u"Couldn't map start position") - return orig_text - - orig_end_position = None - if end_position in tok_s_to_ns_map: - ns_end_position = tok_s_to_ns_map[end_position] - if ns_end_position in orig_ns_to_s_map: - orig_end_position = orig_ns_to_s_map[ns_end_position] - - if orig_end_position is None: - if verbose: - print(u"Couldn't map end position") - return orig_text - - output_text = orig_text[orig_start_position:(orig_end_position + 1)] - return output_text - - -def _compute_softmax(scores): - """Compute softmax probability over raw logits.""" - if not scores: - return [] - - max_score = None - for score in scores: - if max_score is None or score > max_score: - max_score = score - - exp_scores = [] - total_sum = 0.0 - for score in scores: - x = math.exp(score - max_score) - exp_scores.append(x) - total_sum += x - - probs = [] - for score in exp_scores: - probs.append(score / total_sum) - return probs - - -def _get_best_indexes(logits, n_best_size): - """Get the n-best logits from a list.""" - index_and_score = sorted( - enumerate(logits), key=lambda x: x[1], reverse=True) - - best_indexes = [] - for i in range(len(index_and_score)): - if i >= n_best_size: - break - best_indexes.append(index_and_score[i][0]) - return best_indexes - - def make_qid_to_has_ans(examples): qid_to_has_ans = {} for example in examples: - qid_to_has_ans[example.qas_id] = not example.is_impossible + qid_to_has_ans[example['id']] = not example.get('is_impossible', False) return qid_to_has_ans @@ -364,7 +235,7 @@ def lower(text): return text.lower() if not s: - return [] + return '' else: return 
white_space_fix(remove_articles(remove_punc(lower(s)))) @@ -398,9 +269,9 @@ def get_raw_scores(examples, preds, is_whitespace_splited=True): exact_scores = {} f1_scores = {} for example in examples: - qid = example.qas_id + qid = example['id'] gold_answers = [ - text for text in example.orig_answer_text if normalize_answer(text) + text for text in example['answers'] if normalize_answer(text) ] if not gold_answers: # For unanswerable questions, only correct answer is empty string @@ -409,7 +280,6 @@ def get_raw_scores(examples, preds, is_whitespace_splited=True): print('Missing prediction for %s' % qid) continue a_pred = preds[qid] - # Take max over all gold answers exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) f1_scores[qid] = max( diff --git a/paddlenlp/transformers/roberta/tokenizer.py b/paddlenlp/transformers/roberta/tokenizer.py index d07ad54d88361..4960f595b2d7d 100644 --- a/paddlenlp/transformers/roberta/tokenizer.py +++ b/paddlenlp/transformers/roberta/tokenizer.py @@ -117,7 +117,7 @@ def _tokenize(self, text): split_tokens.append(sub_token) return split_tokens - def __call__(self, text): + def tokenize(self, text): """ End-to-end tokenization for RoBERTa models. Args: