From b05b4cae5d85986e484e70174b2af275538a9b6d Mon Sep 17 00:00:00 2001 From: david Date: Tue, 14 Mar 2023 14:23:43 +0100 Subject: [PATCH] fix: resolved #2483 --- src/argilla/client/datasets.py | 13 ++++++++++--- tests/client/test_dataset.py | 16 ++++++++++------ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/argilla/client/datasets.py b/src/argilla/client/datasets.py index fd1eb386bd..bdfd0cffc6 100644 --- a/src/argilla/client/datasets.py +++ b/src/argilla/client/datasets.py @@ -720,13 +720,14 @@ def _prepare_for_training_with_transformers( inputs_keys = {key: None for rec in self._records for key in rec.inputs if rec.annotation is not None}.keys() - ds_dict = {**{key: [] for key in inputs_keys}, "label": []} + ds_dict = {"id": [], **{key: [] for key in inputs_keys}, "label": []} for rec in self._records: if rec.annotation is None: continue for key in inputs_keys: ds_dict[key].append(rec.inputs.get(key)) ds_dict["label"].append(rec.annotation) + ds_dict["id"].append(str(rec.id)) if self._records[0].multi_label: labels = {label: None for labels in ds_dict["label"] for label in labels} @@ -741,6 +742,7 @@ def _prepare_for_training_with_transformers( ) feature_dict = { + "id": datasets.Value("string"), **{key: datasets.Value("string") for key in inputs_keys}, "label": [class_label] if self._records[0].multi_label else class_label, } @@ -757,6 +759,7 @@ def _prepare_for_training_with_transformers( feature_dict["binarized_label"] = feature_dict["label"] ds = datasets.Dataset.from_dict( { + "id": ds["id"], "text": ds["text"], "context": ds_dict["context"], "label": labels, @@ -787,6 +790,7 @@ def _prepare_for_training_with_spacy(self, nlp: "spacy.Language", records: List[ else: text = record.text doc = nlp.make_doc(text) + doc.user_data["id"] = record.id cats = dict.fromkeys(all_labels, 0) @@ -990,7 +994,7 @@ def spans2iob(example): new_features = ds.features.copy() new_features["ner_tags"] = datasets.Sequence(feature=class_tags) ds = 
ds.cast(new_features) - ds = ds.remove_columns(set(ds.column_names) - set(["tokens", "ner_tags"])) + ds = ds.remove_columns(set(ds.column_names) - set(["id", "tokens", "ner_tags"])) if test_size is not None and test_size != 0: ds = ds.train_test_split(train_size=train_size, test_size=test_size, seed=seed) @@ -1009,6 +1013,7 @@ def _prepare_for_training_with_spacy(self, nlp: "spacy.Language", records: List[ continue doc = nlp.make_doc(record.text) + doc.user_data["id"] = record.id entities = [] for anno in record.annotation: @@ -1250,14 +1255,16 @@ def _prepare_for_training_with_transformers( ): import datasets - ds_dict = {"text": [], "target": []} + ds_dict = {"id": [], "text": [], "target": []} for rec in self._records: if rec.annotation is None: continue + ds_dict["id"].append(str(rec.id)) ds_dict["text"].append(rec.text) ds_dict["target"].append(rec.annotation) feature_dict = { + "id": datasets.Value("string"), "text": datasets.Value("string"), "target": datasets.Value("string"), } diff --git a/tests/client/test_dataset.py b/tests/client/test_dataset.py index 163460d44d..c2e707d9b4 100644 --- a/tests/client/test_dataset.py +++ b/tests/client/test_dataset.py @@ -342,9 +342,9 @@ def test_prepare_for_training(self, request, records): train = ds.prepare_for_training(seed=42) if not ds[0].multi_label: - column_names = ["text", "context", "label"] + column_names = ["id", "text", "context", "label"] else: - column_names = ["text", "context", "label", "binarized_label"] + column_names = ["id", "text", "context", "label", "binarized_label"] assert isinstance(train, datasets.Dataset) assert train.column_names == column_names @@ -361,7 +361,7 @@ def test_prepare_for_training(self, request, records): assert len(train_test["train"]) == 1 assert len(train_test["test"]) == 1 for split in ["train", "test"]: - assert train_test[split].column_names == column_names + assert set(train_test[split].column_names) == set(column_names) @pytest.mark.parametrize( "records", @@ -393,6 +393,8 
@@ def test_prepare_for_training_with_spacy(self, request, records): docs_test = list(train.get_docs(nlp.vocab)) assert len(list(docs_train)) == 1 assert len(list(docs_test)) == 1 + assert "id" in docs_train[0].user_data + assert "id" in docs_test[0].user_data @pytest.mark.parametrize( "records", @@ -648,6 +650,8 @@ def test_prepare_for_training_with_spacy(self): assert isinstance(test, spacy.tokens.DocBin) assert len(train) == 80 assert len(test) == 20 + assert "id" in train[0].user_data + assert "id" in test[0].user_data @pytest.mark.skipif( _HF_HUB_ACCESS_TOKEN is None, @@ -690,7 +694,7 @@ def test_prepare_for_training(self): r.annotation = [(label, start, end) for label, start, end, _ in r.prediction] train = rb_dataset.prepare_for_training() - assert (set(train.column_names)) == set(["tokens", "ner_tags"]) + assert (set(train.column_names)) == set(["id", "tokens", "ner_tags"]) assert isinstance(train, datasets.DatasetD.Dataset) or isinstance(train, datasets.Dataset) assert "ner_tags" in train.column_names @@ -857,7 +861,7 @@ def test_prepare_for_training(self): train = ds.prepare_for_training(train_size=1, seed=42) assert isinstance(train, datasets.Dataset) - assert train.column_names == ["text", "target"] + assert set(train.column_names) == set(["id", "text", "target"]) assert len(train) == 10 assert train[1]["text"] == "mock" assert train[1]["target"] == "mock" @@ -868,7 +872,7 @@ def test_prepare_for_training(self): assert len(train_test["train"]) == 5 assert len(train_test["test"]) == 5 for split in ["train", "test"]: - assert train_test[split].column_names == ["text", "target"] + assert set(train_test[split].column_names) == set(["id", "text", "target"]) def test_prepare_for_training_with_spacy(self): ds = rg.DatasetForText2Text(