Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable label alignment for token classification datasets #4277

Merged
merged 6 commits into from
May 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4408,30 +4408,54 @@ def align_labels_with_mapping(self, label2id: Dict, label_column: str) -> "Datas
raise ValueError(f"Column ({label_column}) not in table columns ({self._data.column_names}).")

label_feature = self.features[label_column]
if not isinstance(label_feature, ClassLabel):
if not (
isinstance(label_feature, ClassLabel)
or (isinstance(label_feature, Sequence) and isinstance(label_feature.feature, ClassLabel))
):
raise ValueError(
f"Aligning labels with a mapping is only supported for {ClassLabel.__name__} column, and column {label_feature} is {type(label_feature).__name__}."
f"Aligning labels with a mapping is only supported for {ClassLabel.__name__} column or {Sequence.__name__} column with the inner type {ClassLabel.__name__}, and column {label_feature} is of type {type(label_feature).__name__}."
)

# Sort input mapping by ID value to ensure the label names are aligned
label2id = dict(sorted(label2id.items(), key=lambda item: item[1]))
label_names = list(label2id.keys())
# Some label mappings use uppercase label names so we lowercase them during alignment
label2id = {k.lower(): v for k, v in label2id.items()}
int2str_function = label_feature.int2str
int2str_function = (
label_feature.int2str if isinstance(label_feature, ClassLabel) else label_feature.feature.int2str
)

def process_label_ids(batch):
dset_label_names = [
int2str_function(label_id).lower() if label_id is not None else None
for label_id in batch[label_column]
]
batch[label_column] = [
label2id[label_name] if label_name is not None else None for label_name in dset_label_names
]
return batch
if isinstance(label_feature, ClassLabel):

def process_label_ids(batch):
dset_label_names = [
int2str_function(label_id).lower() if label_id is not None else None
for label_id in batch[label_column]
]
batch[label_column] = [
label2id[label_name] if label_name is not None else None for label_name in dset_label_names
]
return batch

else:

def process_label_ids(batch):
dset_label_names = [
[int2str_function(label_id).lower() if label_id is not None else None for label_id in seq]
for seq in batch[label_column]
]
batch[label_column] = [
[label2id[label_name] if label_name is not None else None for label_name in seq]
for seq in dset_label_names
]
return batch

features = self.features.copy()
features[label_column] = ClassLabel(num_classes=len(label_names), names=label_names)
features[label_column] = (
ClassLabel(num_classes=len(label_names), names=label_names)
if isinstance(label_feature, ClassLabel)
else Sequence(ClassLabel(num_classes=len(label_names), names=label_names))
)
return self.map(process_label_ids, features=features, batched=True, desc="Aligning the labels")


Expand Down
30 changes: 29 additions & 1 deletion tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3331,7 +3331,7 @@ def test_task_templates_empty_after_preparation(self):
with dset.prepare_for_task(task="text-classification") as dset:
self.assertIsNone(dset.info.task_templates)

def test_align_labels_with_mapping(self):
def test_align_labels_with_mapping_classification(self):
features = Features(
{
"input_text": Value("string"),
Expand All @@ -3349,6 +3349,34 @@ def test_align_labels_with_mapping(self):
aligned_label_names = [dset.features["input_labels"].int2str(idx) for idx in dset["input_labels"]]
self.assertListEqual(expected_label_names, aligned_label_names)

def test_align_labels_with_mapping_ner(self):
features = Features(
{
"input_text": Value("string"),
"input_labels": Sequence(
ClassLabel(
names=[
"b-per",
"i-per",
"o",
]
)
),
}
)
data = {"input_text": [["Optimus", "Prime", "is", "a", "Transformer"]], "input_labels": [[0, 1, 2, 2, 2]]}
label2id = {"B-PER": 2, "I-PER": 1, "O": 0}
id2label = {v: k for k, v in label2id.items()}
expected_labels = [[2, 1, 0, 0, 0]]
expected_label_names = [[id2label[idx] for idx in seq] for seq in expected_labels]
with Dataset.from_dict(data, features=features) as dset:
with dset.align_labels_with_mapping(label2id, "input_labels") as dset:
self.assertListEqual(expected_labels, dset["input_labels"])
aligned_label_names = [
dset.features["input_labels"].feature.int2str(idx) for idx in dset["input_labels"]
]
self.assertListEqual(expected_label_names, aligned_label_names)

def test_concatenate_with_no_task_templates(self):
info = DatasetInfo(task_templates=None)
data = {"text": ["i love transformers!"], "labels": [1]}
Expand Down