Skip to content

Commit

Permalink
fix: resolved #2483
Browse files (browse the repository at this point in the history)
  • Loading branch information
davidberenstein1957 committed Mar 14, 2023
1 parent 5ff0a48 commit b05b4ca
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
13 changes: 10 additions & 3 deletions src/argilla/client/datasets.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -720,13 +720,14 @@ def _prepare_for_training_with_transformers(

inputs_keys = {key: None for rec in self._records for key in rec.inputs if rec.annotation is not None}.keys()

- ds_dict = {**{key: [] for key in inputs_keys}, "label": []}
+ ds_dict = {"id": [], **{key: [] for key in inputs_keys}, "label": []}
for rec in self._records:
if rec.annotation is None:
continue
for key in inputs_keys:
ds_dict[key].append(rec.inputs.get(key))
ds_dict["label"].append(rec.annotation)
+ ds_dict["id"].append(str(rec.id))

if self._records[0].multi_label:
labels = {label: None for labels in ds_dict["label"] for label in labels}
Expand All @@ -741,6 +742,7 @@ def _prepare_for_training_with_transformers(
)

feature_dict = {
+ "id": datasets.Value("string"),
**{key: datasets.Value("string") for key in inputs_keys},
"label": [class_label] if self._records[0].multi_label else class_label,
}
Expand All @@ -757,6 +759,7 @@ def _prepare_for_training_with_transformers(
feature_dict["binarized_label"] = feature_dict["label"]
ds = datasets.Dataset.from_dict(
{
+ "id": ds["id"],
"text": ds["text"],
"context": ds_dict["context"],
"label": labels,
Expand Down Expand Up @@ -787,6 +790,7 @@ def _prepare_for_training_with_spacy(self, nlp: "spacy.Language", records: List[
else:
text = record.text
doc = nlp.make_doc(text)
+ doc.user_data["id"] = record.id

cats = dict.fromkeys(all_labels, 0)

Expand Down Expand Up @@ -990,7 +994,7 @@ def spans2iob(example):
new_features = ds.features.copy()
new_features["ner_tags"] = datasets.Sequence(feature=class_tags)
ds = ds.cast(new_features)
- ds = ds.remove_columns(set(ds.column_names) - set(["tokens", "ner_tags"]))
+ ds = ds.remove_columns(set(ds.column_names) - set(["id", "tokens", "ner_tags"]))

if test_size is not None and test_size != 0:
ds = ds.train_test_split(train_size=train_size, test_size=test_size, seed=seed)
Expand All @@ -1009,6 +1013,7 @@ def _prepare_for_training_with_spacy(self, nlp: "spacy.Language", records: List[
continue

doc = nlp.make_doc(record.text)
+ doc.user_data["id"] = record.id
entities = []

for anno in record.annotation:
Expand Down Expand Up @@ -1250,14 +1255,16 @@ def _prepare_for_training_with_transformers(
):
import datasets

- ds_dict = {"text": [], "target": []}
+ ds_dict = {"id": [], "text": [], "target": []}
for rec in self._records:
if rec.annotation is None:
continue
+ ds_dict["id"].append(rec.id)
ds_dict["text"].append(rec.text)
ds_dict["target"].append(rec.annotation)

feature_dict = {
+ "id": datasets.Value("string"),
"text": datasets.Value("string"),
"target": datasets.Value("string"),
}
Expand Down
16 changes: 10 additions & 6 deletions tests/client/test_dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -342,9 +342,9 @@ def test_prepare_for_training(self, request, records):
train = ds.prepare_for_training(seed=42)

if not ds[0].multi_label:
- column_names = ["text", "context", "label"]
+ column_names = ["id", "text", "context", "label"]
else:
- column_names = ["text", "context", "label", "binarized_label"]
+ column_names = ["id", "text", "context", "label", "binarized_label"]

assert isinstance(train, datasets.Dataset)
assert train.column_names == column_names
Expand All @@ -361,7 +361,7 @@ def test_prepare_for_training(self, request, records):
assert len(train_test["train"]) == 1
assert len(train_test["test"]) == 1
for split in ["train", "test"]:
- assert train_test[split].column_names == column_names
+ assert set(train_test[split].column_names) == set(column_names)

@pytest.mark.parametrize(
"records",
Expand Down Expand Up @@ -393,6 +393,8 @@ def test_prepare_for_training_with_spacy(self, request, records):
docs_test = list(train.get_docs(nlp.vocab))
assert len(list(docs_train)) == 1
assert len(list(docs_test)) == 1
+ assert "id" in docs_train[0].user_data
+ assert "id" in docs_test[0].user_data

@pytest.mark.parametrize(
"records",
Expand Down Expand Up @@ -648,6 +650,8 @@ def test_prepare_for_training_with_spacy(self):
assert isinstance(test, spacy.tokens.DocBin)
assert len(train) == 80
assert len(test) == 20
+ assert "id" in train[0].user_data
+ assert "id" in test[0].user_data

@pytest.mark.skipif(
_HF_HUB_ACCESS_TOKEN is None,
Expand Down Expand Up @@ -690,7 +694,7 @@ def test_prepare_for_training(self):
r.annotation = [(label, start, end) for label, start, end, _ in r.prediction]

train = rb_dataset.prepare_for_training()
- assert (set(train.column_names)) == set(["tokens", "ner_tags"])
+ assert (set(train.column_names)) == set(["id", "tokens", "ner_tags"])

assert isinstance(train, datasets.DatasetD.Dataset) or isinstance(train, datasets.Dataset)
assert "ner_tags" in train.column_names
Expand Down Expand Up @@ -857,7 +861,7 @@ def test_prepare_for_training(self):
train = ds.prepare_for_training(train_size=1, seed=42)

assert isinstance(train, datasets.Dataset)
- assert train.column_names == ["text", "target"]
+ assert set(train.column_names) == set(["id", "text", "target"])
assert len(train) == 10
assert train[1]["text"] == "mock"
assert train[1]["target"] == "mock"
Expand All @@ -868,7 +872,7 @@ def test_prepare_for_training(self):
assert len(train_test["train"]) == 5
assert len(train_test["test"]) == 5
for split in ["train", "test"]:
- assert train_test[split].column_names == ["text", "target"]
+ assert set(train_test[split].column_names) == set(["id", "text", "target"])

def test_prepare_for_training_with_spacy(self):
ds = rg.DatasetForText2Text(
Expand Down

0 comments on commit b05b4ca

Please sign in to comment.