From 8e85947c97c661f027063b65d04883f7b86e3df0 Mon Sep 17 00:00:00 2001 From: gongel Date: Fri, 29 Jul 2022 11:01:14 +0000 Subject: [PATCH 1/5] add hf ds and upgrade example --- model_zoo/ernie-m/run_classifier.py | 80 ++++---- paddlenlp/datasets/hf_datasets/xnli.py | 251 +++++++++++++++++++++++++ 2 files changed, 294 insertions(+), 37 deletions(-) create mode 100644 paddlenlp/datasets/hf_datasets/xnli.py diff --git a/model_zoo/ernie-m/run_classifier.py b/model_zoo/ernie-m/run_classifier.py index 0f7fbb63fc05..9039fd508e50 100644 --- a/model_zoo/ernie-m/run_classifier.py +++ b/model_zoo/ernie-m/run_classifier.py @@ -26,11 +26,11 @@ from paddle.io import Dataset, BatchSampler, DistributedBatchSampler, DataLoader from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer from paddlenlp.transformers import LinearDecayWithWarmup -from paddlenlp.datasets import load_dataset -from paddlenlp.data import Stack, Tuple, Pad +from datasets import load_dataset from paddle.metric import Accuracy from paddlenlp.ops.optimizer import layerwise_lr_decay from paddle.optimizer import AdamW +from paddlenlp.data import DataCollatorWithPadding all_languages = [ "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr", @@ -137,6 +137,9 @@ def parse_args(): type=str, choices=["cpu", "gpu", "xpu"], help="The device to select to train the model, is must be cpu/gpu/xpu.") + parser.add_argument("--overwrite_cache", + action="store_true", + help="Whether to overwrite cache for dataset.") parser.add_argument("--use_amp", type=distutils.util.strtobool, default=False, @@ -164,8 +167,8 @@ def evaluate(model, loss_fct, metric, data_loader, language): model.eval() metric.reset() for batch in data_loader: - input_ids, position_ids, attention_mask, labels = batch - logits = model(input_ids, position_ids, attention_mask) + labels = batch.pop("labels") + logits = model(**batch) loss = loss_fct(logits, labels) correct = metric.compute(logits, labels) metric.update(correct) @@ -178,21 +181,25 @@ def evaluate(model, loss_fct, metric, data_loader, language): def convert_example(example, tokenizer, max_seq_length=256): """convert a example into necessary features""" - # Get the label - label = example["label"] - premise = example["premise"] - hypothesis = example["hypothesis"] # Convert raw text to feature - example = tokenizer(premise, - text_pair=hypothesis, - max_seq_len=max_seq_length) - return example["input_ids"], example["position_ids"], example[ - "attention_mask"], label - - -def get_test_dataloader(args, language, batchify_fn, trans_func): - test_ds = load_dataset("xnli", language, splits="test") - test_ds = test_ds.map(trans_func, lazy=True) + tokenized_example = tokenizer(example["premise"], + text_pair=example["hypothesis"], + max_length=max_seq_length, + padding=False, + truncation=True, + return_position_ids=True, + return_attention_mask=True, + return_token_type_ids=False) + return tokenized_example + + +def get_test_dataloader(args, language, batchify_fn, trans_func, + remove_columns): + test_ds = load_dataset("xnli", language, split="test") + test_ds = test_ds.map(trans_func, + batched=True, + remove_columns=remove_columns, + load_from_cache_file=not args.overwrite_cache) test_batch_sampler = BatchSampler(test_ds, batch_size=args.batch_size, shuffle=False) @@ -220,11 +227,7 @@ def __getitem__(self, idx): last = language_idx - 1 if language_idx > 0 else language_idx sample_idx = idx - self.cumsum_len[last] if idx >= self.cumsum_len[ last] else idx - input_ids = self.datasets[language_idx][sample_idx][0] - position_ids = self.datasets[language_idx][sample_idx][1] - attention_mask = self.datasets[language_idx][sample_idx][2] - label = self.datasets[language_idx][sample_idx][3] - return input_ids, position_ids, attention_mask, label + return self.datasets[int(language_idx)][int(sample_idx)] def __len__(self): return self.cumsum_len[-1] @@ -240,25 +243,28 @@ def do_train(args): trans_func = partial(convert_example, tokenizer=tokenizer, max_seq_length=args.max_seq_length) + remove_columns = ["premise", "hypothesis"] if args.task_type == "cross-lingual-transfer": - train_ds = load_dataset("xnli", "en", splits="train") - train_ds = train_ds.map(trans_func, lazy=True) + train_ds = load_dataset("xnli", "en", split="train") + train_ds = train_ds.map(trans_func, + batched=True, + remove_columns=remove_columns, + load_from_cache_file=not args.overwrite_cache) elif args.task_type == "translate-train-all": all_train_ds = [] for language in all_languages: - train_ds = load_dataset("xnli", language, splits="train") - all_train_ds.append(train_ds.map(trans_func, lazy=True)) + train_ds = load_dataset("xnli", language, split="train") + all_train_ds.append( + train_ds.map(trans_func, + batched=True, + remove_columns=remove_columns, + load_from_cache_file=not args.overwrite_cache)) train_ds = XnliDataset(all_train_ds) train_batch_sampler = DistributedBatchSampler(train_ds, batch_size=args.batch_size, shuffle=True) - batchify_fn = lambda samples, fn=Tuple( - Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"), # input_ids - Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64" - ), # position_ids - Pad(axis=0, pad_val=0, dtype="int64"), # attention_mask - Stack(dtype="int64") # labels - ): fn(samples) + batchify_fn = DataCollatorWithPadding(tokenizer) + train_data_loader = DataLoader(dataset=train_ds, batch_sampler=train_batch_sampler, collate_fn=batchify_fn, @@ -318,11 +324,11 @@ def do_train(args): for epoch in range(num_train_epochs): for step, batch in enumerate(train_data_loader): global_step += 1 - input_ids, position_ids, attention_mask, labels = batch + labels = batch.pop("labels") with paddle.amp.auto_cast( args.use_amp, custom_white_list=["layer_norm", "softmax", "gelu"]): - logits = model(input_ids, position_ids, attention_mask) + logits = model(**batch) loss = loss_fct(logits, labels) if args.use_amp: scaled_loss = scaler.scale(loss) @@ -344,7 +350,7 @@ def do_train(args): for language in all_languages: tic_eval = time.time() test_data_loader = get_test_dataloader( - args, language, batchify_fn, trans_func) + args, language, batchify_fn, trans_func, remove_columns) evaluate(model, loss_fct, metric, test_data_loader, language) print("eval done total : %s s" % (time.time() - tic_eval)) diff --git a/paddlenlp/datasets/hf_datasets/xnli.py b/paddlenlp/datasets/hf_datasets/xnli.py new file mode 100644 index 000000000000..de461ac00b0d --- /dev/null +++ b/paddlenlp/datasets/hf_datasets/xnli.py @@ -0,0 +1,251 @@ +# coding=utf-8 +# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""XNLI: The Cross-Lingual NLI Corpus.""" + +import collections +import csv +import os +from contextlib import ExitStack + +import datasets + +_CITATION = """\ +@InProceedings{conneau2018xnli, + author = {Conneau, Alexis + and Rinott, Ruty + and Lample, Guillaume + and Williams, Adina + and Bowman, Samuel R. + and Schwenk, Holger + and Stoyanov, Veselin}, + title = {XNLI: Evaluating Cross-lingual Sentence Representations}, + booktitle = {Proceedings of the 2018 Conference on Empirical Methods + in Natural Language Processing}, + year = {2018}, + publisher = {Association for Computational Linguistics}, + location = {Brussels, Belgium}, +}""" + +_DESCRIPTION = """\ +XNLI is a subset of a few thousand examples from MNLI which has been translated +into a 14 different languages (some low-ish resource). As with MNLI, the goal is +to predict textual entailment (does sentence A imply/contradict/neither sentence +B) and is a classification task (given two sentences, predict one of three +labels). +""" + +_TRAIN_DATA_URL = "https://bj.bcebos.com/paddlenlp/datasets/XNLI-MT-1.0.zip" +_TESTVAL_DATA_URL = "https://bj.bcebos.com/paddlenlp/datasets/XNLI-1.0.zip" + +_LANGUAGES = ("ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", + "tr", "ur", "vi", "zh") + + +class XnliConfig(datasets.BuilderConfig): + """BuilderConfig for XNLI.""" + + def __init__(self, language: str, languages=None, **kwargs): + """BuilderConfig for XNLI. + + Args: + language: One of ar,bg,de,el,en,es,fr,hi,ru,sw,th,tr,ur,vi,zh, or all_languages + **kwargs: keyword arguments forwarded to super. + """ + super(XnliConfig, self).__init__(**kwargs) + self.language = language + if language != "all_languages": + self.languages = [language] + else: + self.languages = languages if languages is not None else _LANGUAGES + + +class Xnli(datasets.GeneratorBasedBuilder): + """XNLI: The Cross-Lingual NLI Corpus. Version 1.0.""" + + VERSION = datasets.Version("1.1.0", "") + BUILDER_CONFIG_CLASS = XnliConfig + BUILDER_CONFIGS = [ + XnliConfig( + name=lang, + language=lang, + version=datasets.Version("1.1.0", ""), + description=f"Plain text import of XNLI for the {lang} language", + ) for lang in _LANGUAGES + ] + [ + XnliConfig( + name="all_languages", + language="all_languages", + version=datasets.Version("1.1.0", ""), + description="Plain text import of XNLI for all languages", + ) + ] + + def _info(self): + if self.config.language == "all_languages": + features = datasets.Features({ + "premise": + datasets.Translation(languages=_LANGUAGES, ), + "hypothesis": + datasets.TranslationVariableLanguages(languages=_LANGUAGES, ), + "label": + datasets.ClassLabel( + names=["entailment", "neutral", "contradiction"]), + }) + else: + features = datasets.Features({ + "premise": + datasets.Value("string"), + "hypothesis": + datasets.Value("string"), + "label": + datasets.ClassLabel( + names=["entailment", "neutral", "contradiction"]), + }) + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=features, + # No default supervised_keys (as we have to pass both premise + # and hypothesis as input). + supervised_keys=None, + homepage="https://www.nyu.edu/projects/bowman/xnli/", + citation=_CITATION, + ) + + def _split_generators(self, dl_manager): + dl_dirs = dl_manager.download_and_extract({ + "train_data": + _TRAIN_DATA_URL, + "testval_data": + _TESTVAL_DATA_URL, + }) + train_dir = os.path.join(dl_dirs["train_data"], "XNLI-MT-1.0", + "multinli") + testval_dir = os.path.join(dl_dirs["testval_data"], "XNLI-1.0") + return [ + datasets.SplitGenerator( + name=datasets.Split.TRAIN, + gen_kwargs={ + "filepaths": [ + os.path.join(train_dir, f"multinli.train.{lang}.tsv") + for lang in self.config.languages + ], + "data_format": + "XNLI-MT", + }, + ), + datasets.SplitGenerator( + name=datasets.Split.TEST, + gen_kwargs={ + "filepaths": [os.path.join(testval_dir, "xnli.test.tsv")], + "data_format": "XNLI" + }, + ), + datasets.SplitGenerator( + name=datasets.Split.VALIDATION, + gen_kwargs={ + "filepaths": [os.path.join(testval_dir, "xnli.dev.tsv")], + "data_format": "XNLI" + }, + ), + ] + + def _generate_examples(self, data_format, filepaths): + """This function returns the examples in the raw (text) form.""" + + if self.config.language == "all_languages": + if data_format == "XNLI-MT": + with ExitStack() as stack: + files = [ + stack.enter_context(open(filepath, encoding="utf-8")) + for filepath in filepaths + ] + readers = [ + csv.DictReader(file, + delimiter="\t", + quoting=csv.QUOTE_NONE) for file in files + ] + for row_idx, rows in enumerate(zip(*readers)): + yield row_idx, { + "premise": { + lang: row["premise"] + for lang, row in zip(self.config.languages, + rows) + }, + "hypothesis": { + lang: row["hypo"] + for lang, row in zip(self.config.languages, + rows) + }, + "label": + rows[0]["label"].replace("contradictory", + "contradiction"), + } + else: + rows_per_pair_id = collections.defaultdict(list) + for filepath in filepaths: + with open(filepath, encoding="utf-8") as f: + reader = csv.DictReader(f, + delimiter="\t", + quoting=csv.QUOTE_NONE) + for row in reader: + rows_per_pair_id[row["pairID"]].append(row) + + for rows in rows_per_pair_id.values(): + premise = { + row["language"]: row["sentence1"] + for row in rows + } + hypothesis = { + row["language"]: row["sentence2"] + for row in rows + } + yield rows[0]["pairID"], { + "premise": premise, + "hypothesis": hypothesis, + "label": rows[0]["gold_label"], + } + else: + if data_format == "XNLI-MT": + for file_idx, filepath in enumerate(filepaths): + with open(filepath, encoding="utf-8") as file: + reader = csv.DictReader(file, + delimiter="\t", + quoting=csv.QUOTE_NONE) + for row_idx, row in enumerate(reader): + key = str(file_idx) + "_" + str(row_idx) + yield key, { + "premise": + row["premise"], + "hypothesis": + row["hypo"], + "label": + row["label"].replace("contradictory", + "contradiction"), + } + else: + for filepath in filepaths: + with open(filepath, encoding="utf-8") as f: + reader = csv.DictReader(f, + delimiter="\t", + quoting=csv.QUOTE_NONE) + for row in reader: + if row["language"] == self.config.language: + yield row["pairID"], { + "premise": row["sentence1"], + "hypothesis": row["sentence2"], + "label": row["gold_label"], + } From f560a6b291f4b9c0dab2103e5a129290320ca864 Mon Sep 17 00:00:00 2001 From: gongel Date: Fri, 9 Sep 2022 11:29:19 +0000 Subject: [PATCH 2/5] fix attention mask --- model_zoo/gpt/dataset.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/model_zoo/gpt/dataset.py b/model_zoo/gpt/dataset.py index 88d4c15deec9..160f79c07a6d 100755 --- a/model_zoo/gpt/dataset.py +++ b/model_zoo/gpt/dataset.py @@ -442,17 +442,14 @@ def _construct_sample(self, tokens): labels = tokens[1:] tokens = tokens[:-1] seq_length = len(tokens) - # Attention mask for the attention calulate - attention_mask = np.tri(seq_length, seq_length).reshape( - (1, seq_length, seq_length)) + # No padding, so attention_mask is None + attention_mask = None # The pad and eos tokens do not contribute the loss loss_mask = np.ones(seq_length, dtype="float32") loss_mask[np.where(np.array(tokens) == self.eos_id)] = 0.0 position_ids = np.arange(0, seq_length, dtype="int64") - attention_mask = (attention_mask - 1.0) * 1e9 - attention_mask = attention_mask.astype("float32") labels = np.array(labels, dtype="int64") return [tokens, loss_mask, attention_mask, position_ids, labels] From 28ea1e2f8ba456fbb2ad6fbafa5114f53850e48d Mon Sep 17 00:00:00 2001 From: gongel Date: Sat, 10 Sep 2022 01:31:50 +0000 Subject: [PATCH 3/5] update --- model_zoo/gpt/dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/model_zoo/gpt/dataset.py b/model_zoo/gpt/dataset.py index 160f79c07a6d..d5d706a12251 100755 --- a/model_zoo/gpt/dataset.py +++ b/model_zoo/gpt/dataset.py @@ -442,14 +442,13 @@ def _construct_sample(self, tokens): labels = tokens[1:] tokens = tokens[:-1] seq_length = len(tokens) - # No padding, so attention_mask is None - attention_mask = None # The pad and eos tokens do not contribute the loss loss_mask = np.ones(seq_length, dtype="float32") loss_mask[np.where(np.array(tokens) == self.eos_id)] = 0.0 position_ids = np.arange(0, seq_length, dtype="int64") + attention_mask = loss_mask labels = np.array(labels, dtype="int64") return [tokens, loss_mask, attention_mask, position_ids, labels] From 2e44dda63b76934149e535fb9cc501b0f6541064 Mon Sep 17 00:00:00 2001 From: gongel Date: Mon, 26 Sep 2022 04:05:09 +0000 Subject: [PATCH 4/5] update attention mask --- model_zoo/gpt/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_zoo/gpt/dataset.py b/model_zoo/gpt/dataset.py index d5d706a12251..148d581a1b7f 100755 --- a/model_zoo/gpt/dataset.py +++ b/model_zoo/gpt/dataset.py @@ -448,7 +448,7 @@ def _construct_sample(self, tokens): loss_mask[np.where(np.array(tokens) == self.eos_id)] = 0.0 position_ids = np.arange(0, seq_length, dtype="int64") - attention_mask = loss_mask + attention_mask = np.ones(seq_length, dtype="int64") labels = np.array(labels, dtype="int64") return [tokens, loss_mask, attention_mask, position_ids, labels] From 50f7e8623275ccc1ed6fcda97f86691e5d7fec77 Mon Sep 17 00:00:00 2001 From: gongel Date: Mon, 26 Sep 2022 05:10:20 +0000 Subject: [PATCH 5/5] fix static attention mask --- model_zoo/gpt/run_pretrain_static.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/model_zoo/gpt/run_pretrain_static.py b/model_zoo/gpt/run_pretrain_static.py index 7fed8e3ab211..fd25ab74c187 100644 --- a/model_zoo/gpt/run_pretrain_static.py +++ b/model_zoo/gpt/run_pretrain_static.py @@ -53,10 +53,9 @@ def create_data_holder(args): loss_mask = paddle.static.data(name="loss_mask", shape=[-1, args.max_seq_len], dtype="float32") - attention_mask = paddle.static.data( - name="attention_mask", - shape=[-1, 1, args.max_seq_len, args.max_seq_len], - dtype="float32") + attention_mask = paddle.static.data(name="attention_mask", + shape=[-1, args.max_seq_len], + dtype="int64") position_ids = paddle.static.data(name="position_ids", shape=[-1, args.max_seq_len], dtype="int64")