feat: Allow extending the load_dataset parameters in custom tasks inheriting AbsTask (#299)

* Allow extending the load_dataset parameters

* format

* Fix test

* remove duplicated logic from AbsTask, now handled in the metadata

* add tests

* remove comments, moved to PR

* format

* extend metadata dict from super class

* Remove additional load_data

* test: adding very high level test

* Remove hf_hub_name and add test

* Fix revision in output file

---------

Co-authored-by: gbmarc1 <[email protected]>
gariepyalex and gbmarc1 authored Apr 2, 2024
1 parent 4be555b commit 953780d
Showing 172 changed files with 1,155 additions and 1,015 deletions.
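In short, this commit replaces the separate `hf_hub_name` and `revision` metadata fields with a single `dataset` dictionary that is unpacked directly into `datasets.load_dataset`, so tasks can forward any loading argument without overriding `load_data`. A minimal sketch of the new behaviour (the path and revision are copied from the `SciDocsRR` example in the diff below; the commented-out `split` key is a hypothetical extra argument, not something this commit adds):

```python
import datasets

# The `dataset` entry of TaskMetadata is a plain dict of load_dataset arguments.
dataset_kwargs = {
    "path": "mteb/scidocs-reranking",
    "revision": "d3c5e1fc0b855ab6097bf1cda04dd73947d7caab",
    # "split": "test",  # hypothetical: any other load_dataset argument could be added here
}

# This mirrors what the updated AbsTask.load_data does internally.
ds = datasets.load_dataset(**dataset_kwargs)
```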
41 changes: 15 additions & 26 deletions docs/adding_a_dataset.md
@@ -7,7 +7,7 @@ To add a new dataset to MTEB, you need to do three things:
2) Add metadata to the task
3) Submit the edits to the [MTEB](https://github.com/embeddings-benchmark/mteb) repository

If you have any questions regarding this process feel free to open a discussion [thread](https://github.com/embeddings-benchmark/mteb/discussions).

> Note: When we mention adding a dataset we refer to a subclass of one of the abstasks.
@@ -28,13 +28,15 @@ class SciDocsReranking(AbsTaskReranking):
name="SciDocsRR",
description="Ranking of related scientific papers based on their title.",
reference="https://allenai.org/data/scidocs",
hf_hub_name="mteb/scidocs-reranking",
type="Reranking",
category="s2s",
eval_splits=["test"],
eval_langs=["en"],
main_score="map",
revision="d3c5e1fc0b855ab6097bf1cda04dd73947d7caab",
dataset={
"path": "mteb/scidocs-reranking",
"revision": "d3c5e1fc0b855ab6097bf1cda04dd73947d7caab",
},
date=None,
form="written",
domains=["Academic", "Non-fiction"],
@@ -55,27 +57,29 @@ evaluation = MTEB(tasks=[MindSmallReranking()])
evaluation.run(model)
```

> **Note:** for multilingual tasks, make sure your class also inherits from the `MultilingualTask` class like in [this](https://github.com/embeddings-benchmark/mteb-draft/blob/main/mteb/tasks/Classification/MTOPIntentClassification.py) example.
> For cross-lingual tasks, make sure your class also inherits from the `CrosslingualTask` class like in [this](https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/BitextMining/TatoebaBitextMining.py) example.


### A Detailed Example
Often the dataset from HuggingFace is not in the format expected by MTEB. To resolve this you can either change the format on Hugging Face or add a `dataset_transform` method to your dataset to transform it into the right format on the fly. Here is an example along with some design considerations:

```python
class VGClustering(AbsTaskClustering):
metadata = TaskMetadata(
name="VGClustering",
description="Articles and their classes (e.g. sports) from VG news articles extracted from Norsk Aviskorpus.",
reference="https://huggingface.co/datasets/navjordj/VG_summarization",
hf_hub_name="navjordj/VG_summarization",
type="Clustering",
category="p2p",
eval_splits=["test"],
eval_langs=["nb"],
main_score="v_measure",
revision="d4c5a8ba10ae71224752c727094ac4c46947fa29",
dataset={
"path": "navjordj/VG_summarization",
"revision": "d4c5a8ba10ae71224752c727094ac4c46947fa29",
},
date=("2012-01-01", "2020-01-01"),
form="written",
domains=["Academic", "Non-fiction"],
@@ -88,21 +92,6 @@ class VGClustering(AbsTaskClustering):
bibtex_citation= ... # removed for brevity
)

def load_data(self, **kwargs: dict): # noqa: ARG002
"""
Load dataset from HuggingFace hub
"""
if self.data_loaded:
return

self.dataset: datasets.DatasetDict = datasets.load_dataset(
self.description["hf_hub_name"],
revision=self.description.get("revision"),
)

self.dataset_transform()
self.data_loaded = True

def dataset_transform(self):
splits = self.description["eval_splits"]

@@ -136,7 +125,7 @@ class VGClustering(AbsTaskClustering):
labels_batched = list(batched(labels, 512))

# reduce the size of the dataset as we see that we obtain consistent scores (if we change the seed) even
# with only 512x4 samples.
documents_batched = documents_batched[:4]
labels_batched = labels_batched[:4]

@@ -227,7 +216,7 @@ These domain subtypes were introduced in the [Scandinavian Embedding Benchmark]

Once you are finished, create a PR to the [MTEB](https://github.com/embeddings-benchmark/mteb) repository. If you haven't created a PR before, please refer to the [GitHub documentation](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/).

The PR will be reviewed by one of the organizers or contributors who might ask you to change things. Once the PR is approved the dataset will be added into the main repository.


Before submitting, here is a checklist you should consider completing:
@@ -251,5 +240,5 @@ evaluation = MTEB(tasks=[YourNewTask()])
- [ ] `intfloat/multilingual-e5-small`
- [ ] I have checked that the performance is neither trivial (both models gain close to perfect scores) nor random (both models gain close to random scores).
- [ ] I have considered the size of the dataset and reduced it if it is too big (2048 examples is typically large enough for most tasks)
- [ ] Run tests locally to make sure nothing is broken using `make test`.
- [ ] Run the formatter to format the code using `make lint`.
18 changes: 11 additions & 7 deletions mteb/abstasks/AbsTask.py
@@ -26,23 +26,27 @@ def __init__(self, seed=42, **kwargs):
torch.manual_seed(self.seed)
torch.cuda.manual_seed_all(self.seed)

def dataset_transform(self):
"""
Transform operations applied to the dataset after loading.
Override this method if your dataset requires any transformation.
"""
pass

def load_data(self, **kwargs):
"""
Load dataset from HuggingFace hub
"""
if self.data_loaded:
return

# TODO: add split argument
self.dataset = datasets.load_dataset(
self.metadata_dict["hf_hub_name"],
revision=self.metadata_dict.get("revision", None),
)
self.dataset = datasets.load_dataset(**self.metadata_dict["dataset"])
self.dataset_transform()
self.data_loaded = True

@property
def metadata_dict(self) -> dict[str, str]:
return dict(self.metadata)
metadata_dict = dict(self.metadata)
return metadata_dict

@abstractmethod
def evaluate(self, model, split="test"):
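Because `metadata_dict` now builds and returns a plain dict, a subclass can extend it before `load_data` unpacks the `dataset` entry. A sketch of that pattern under stated assumptions: the subclass name and the `trust_remote_code` flag are purely illustrative, and a real task would also define `metadata` and `evaluate`.

```python
from mteb.abstasks.AbsTask import AbsTask


class MyCustomTask(AbsTask):  # hypothetical subclass, shown only to illustrate the hook
    @property
    def metadata_dict(self) -> dict:
        metadata_dict = super().metadata_dict                  # starts from dict(self.metadata)
        # Add or override any datasets.load_dataset argument before load_data uses it.
        metadata_dict["dataset"]["trust_remote_code"] = True   # illustrative extra argument
        return metadata_dict
```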
7 changes: 3 additions & 4 deletions mteb/abstasks/AbsTaskRetrieval.py
@@ -218,14 +218,13 @@ def load_data(self, **kwargs):
if self.data_loaded:
return
self.corpus, self.queries, self.relevant_docs = {}, {}, {}
dataset_path = self.metadata_dict["dataset"]["path"]
hf_repo_qrels = (
self.metadata_dict["hf_hub_name"] + "-qrels"
if "clarin-knext" in self.metadata_dict["hf_hub_name"]
else None
dataset_path + "-qrels" if "clarin-knext" in dataset_path else None
)
for split in kwargs.get("eval_splits", self.metadata_dict["eval_splits"]):
corpus, queries, qrels = HFDataLoader(
hf_repo=self.metadata_dict["hf_hub_name"],
hf_repo=dataset_path,
hf_repo_qrels=hf_repo_qrels,
streaming=False,
keep_in_memory=False,
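The qrels repository is now derived from the dataset path rather than from `hf_hub_name`. A tiny stand-alone sketch of the naming rule (the path below is only illustrative):

```python
# Same rule as in the diff above, applied to an illustrative dataset path.
dataset_path = "clarin-knext/arguana-pl"
hf_repo_qrels = dataset_path + "-qrels" if "clarin-knext" in dataset_path else None
print(hf_repo_qrels)  # -> clarin-knext/arguana-pl-qrels
```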
4 changes: 1 addition & 3 deletions mteb/abstasks/CrosslingualTask.py
@@ -25,8 +25,6 @@ def load_data(self, **kwargs):
self.dataset = {}
for lang in self.langs:
self.dataset[lang] = datasets.load_dataset(
self.metadata_dict["hf_hub_name"],
lang,
revision=self.metadata_dict.get("revision", None),
name=lang, **self.metadata_dict["dataset"]
)
self.data_loaded = True
5 changes: 2 additions & 3 deletions mteb/abstasks/MultilingualTask.py
@@ -27,8 +27,7 @@ def load_data(self, **kwargs):
self.dataset = {}
for lang in self.langs:
self.dataset[lang] = datasets.load_dataset(
self.metadata_dict["hf_hub_name"],
lang,
revision=self.metadata_dict.get("revision", None),
name=lang,
**self.metadata_dict.get("dataset", None),
)
self.data_loaded = True
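Multilingual and crosslingual tasks follow the same pattern, passing the language code as the `name` (config) argument on top of the shared `dataset` dict. A sketch, assuming the BUCC values that appear later in this commit and an illustrative pair of language configs:

```python
import datasets

dataset_args = {
    "path": "mteb/bucc-bitext-mining",
    "revision": "d51519689f32196a32af33b075a01d0e7c51e252",
}

data = {}
for lang in ["de-en", "fr-en"]:  # illustrative subset of a task's self.langs
    # The language code becomes the config name; everything else comes from the metadata dict.
    data[lang] = datasets.load_dataset(name=lang, **dataset_args)
```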
32 changes: 28 additions & 4 deletions mteb/abstasks/TaskMetadata.py
@@ -1,12 +1,15 @@
from __future__ import annotations

import logging
from datetime import date

from pydantic import (
AnyUrl,
BaseModel,
BeforeValidator,
TypeAdapter,
field_validator,
model_validator,
)
from typing_extensions import Annotated, Literal

@@ -91,13 +94,15 @@
SPLIT_NAME = str


logger = logging.getLogger(__name__)


class TaskMetadata(BaseModel):
"""
Metadata for a task.
Args:
hf_hub_name: The name of the dataset for the task on the Hugging Face Hub.
revision: The revision of the dataset for the task on the Hugging Face Hub.
dataset: All arguments to pass to datasets.load_dataset to load the dataset for the task. Refer to https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/loading_methods#datasets.load_dataset
name: The name of the task.
description: A description of the task.
type: The type of the task. These includes "Classification", "Summarization", "STS", "Retrieval", "Reranking", "Clustering",
@@ -124,8 +129,7 @@ class TaskMetadata(BaseModel):
avg_character_length: The average character length of the samples in the dataset. This should only be for the splits evaluated on.
"""

hf_hub_name: str
revision: str
dataset: dict

name: str
description: str
@@ -152,3 +156,23 @@

n_samples: dict[SPLIT_NAME, int] | None
avg_character_length: dict[SPLIT_NAME, float] | None

@field_validator("dataset")
def _check_dataset_path_is_specified(cls, dataset):
"""
This method checks that the dataset path is specified.
"""
if "path" not in dataset or dataset["path"] is None:
raise ValueError(
"You must specify the path to the dataset in the dataset dictionary. "
"See https://huggingface.co/docs/datasets/main/en/package_reference/loading_methods#datasets.load_dataset"
)
return dataset

@field_validator("dataset")
def _check_dataset_revision_is_specified(cls, dataset):
if "revision" not in dataset:
raise ValueError("You must explicitly specify a revision for the dataset (either a SHA or None).")
if dataset["revision"] is None:
logging.warning("It is encouraged to specify a dataset revision for reproducibility")
return dataset
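The two validators above enforce that a `path` is always given and that a `revision` key is present (even if explicitly `None`). A stand-alone sketch of the same checks, outside the pydantic model:

```python
def check_dataset(dataset: dict) -> dict:
    # Mirrors _check_dataset_path_is_specified.
    if "path" not in dataset or dataset["path"] is None:
        raise ValueError("You must specify the path to the dataset in the dataset dictionary.")
    # Mirrors _check_dataset_revision_is_specified.
    if "revision" not in dataset:
        raise ValueError("You must explicitly specify a revision for the dataset (either a SHA or None).")
    if dataset["revision"] is None:
        print("warning: it is encouraged to specify a dataset revision for reproducibility")
    return dataset


check_dataset({"path": "mteb/scidocs-reranking", "revision": None})  # accepted, but warns
# check_dataset({"path": "mteb/scidocs-reranking"})                  # would raise: missing revision
```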
2 changes: 1 addition & 1 deletion mteb/evaluation/MTEB.py
@@ -304,7 +304,7 @@ def run(
# run evaluation
task_results = {
"mteb_version": version("mteb"), # noqa: F405
"dataset_revision": task.metadata_dict.get("revision", None),
"dataset_revision": task.metadata_dict["dataset"].get("revision", None),
"mteb_dataset_name": task.metadata_dict["name"],
}
for split in task_eval_splits:
22 changes: 4 additions & 18 deletions mteb/tasks/BitextMining/da/BornholmskBitextMining.py
@@ -1,23 +1,23 @@
from __future__ import annotations

import datasets

from mteb.abstasks import AbsTaskBitextMining
from mteb.abstasks.TaskMetadata import TaskMetadata


class BornholmBitextMining(AbsTaskBitextMining):
metadata = TaskMetadata(
name="BornholmBitextMining",
hf_hub_name="strombergnlp/bornholmsk_parallel",
dataset={
"path": "strombergnlp/bornholmsk_parallel",
"revision": "3bc5cfb4ec514264fe2db5615fac9016f7251552",
},
description="Danish Bornholmsk Parallel Corpus. Bornholmsk is a Danish dialect spoken on the island of Bornholm, Denmark. Historically it is a part of east Danish which was also spoken in Scania and Halland, Sweden.",
reference="https://aclanthology.org/W19-6138/",
type="BitextMining",
category="s2s",
eval_splits=["test"],
eval_langs=["da", "da-bornholm"],
main_score="f1",
revision="3bc5cfb4ec514264fe2db5615fac9016f7251552",
date=None,
form=None,
domains=None,
@@ -32,20 +32,6 @@ class BornholmBitextMining(AbsTaskBitextMining):
n_samples={"test": 500},
)

def load_data(self, **kwargs):
"""
Load dataset from HuggingFace hub and convert it to the standard format.
"""
if self.data_loaded:
return

self.dataset = datasets.load_dataset(
self.metadata_dict["hf_hub_name"],
revision=self.metadata_dict.get("revision", None),
)
self.dataset_transform()
self.data_loaded = True

def dataset_transform(self):
# Convert to standard format
self.dataset = self.dataset.rename_column("da", "sentence1")
6 changes: 4 additions & 2 deletions mteb/tasks/BitextMining/multilingual/BUCCBitextMining.py
@@ -10,15 +10,17 @@
class BUCCBitextMining(AbsTaskBitextMining, CrosslingualTask):
metadata = TaskMetadata(
name="BUCC",
hf_hub_name="mteb/bucc-bitext-mining",
dataset={
"path": "mteb/bucc-bitext-mining",
"revision": "d51519689f32196a32af33b075a01d0e7c51e252",
},
description="BUCC bitext mining dataset",
reference="https://comparable.limsi.fr/bucc2018/bucc2018-task.html",
type="BitextMining",
category="s2s",
eval_splits=["test"],
eval_langs=_LANGUAGES,
main_score="f1",
revision="d51519689f32196a32af33b075a01d0e7c51e252",
date=None,
form=None,
domains=None,
12 changes: 6 additions & 6 deletions mteb/tasks/BitextMining/multilingual/DiaBLaBitextMining.py
@@ -10,15 +10,17 @@
class DiaBLaBitextMining(AbsTaskBitextMining, CrosslingualTask):
metadata = TaskMetadata(
name="DiaBlaBitextMining",
hf_hub_name="rbawden/DiaBLa",
dataset={
"path": "rbawden/DiaBLa",
"revision": "5345895c56a601afe1a98519ce3199be60a27dba",
},
description="English-French Parallel Corpus. DiaBLa is an English-French dataset for the evaluation of Machine Translation (MT) for informal, written bilingual dialogue.",
reference="https://inria.hal.science/hal-03021633",
type="BitextMining",
category="s2s",
eval_splits=["test"],
eval_langs=["fr-en", "en-fr"],
main_score="f1",
revision="5345895c56a601afe1a98519ce3199be60a27dba",
date=None,
form=None,
domains=None,
@@ -41,11 +43,9 @@ def load_data(self, **kwargs):
return

self.dataset = {}

for lang in self.langs:
self.dataset[lang] = datasets.load_dataset(
self.metadata_dict["hf_hub_name"],
revision=self.metadata_dict.get("revision", None),
)
self.dataset[lang] = datasets.load_dataset(**self.metadata_dict["dataset"])

self.dataset_transform()
self.data_loaded = True
11 changes: 6 additions & 5 deletions mteb/tasks/BitextMining/multilingual/FloresBitextMining.py
@@ -235,15 +235,17 @@ def extend_lang_pairs():
class FloresBitextMining(AbsTaskBitextMining, CrosslingualTask):
metadata = TaskMetadata(
name="FloresBitextMining",
hf_hub_name="facebook/flores",
dataset={
"path": "facebook/flores",
"revision": "80dc3040d19756742c9a18267ab30f54fb8e226b",
},
description="FLORES is a benchmark dataset for machine translation between English and low-resource languages.",
reference="https://huggingface.co/datasets/facebook/flores",
type="BitextMining",
category="s2s",
eval_splits=_SPLIT,
eval_langs=_LANGUAGES_PAIRS,
main_score="f1",
revision="80dc3040d19756742c9a18267ab30f54fb8e226b",
date=None,
form=None,
domains=None,
@@ -267,9 +269,8 @@ def load_data(self, **kwargs):
self.dataset = {}
for lang in self.langs:
self.dataset[lang] = datasets.load_dataset(
self.metadata_dict["hf_hub_name"],
lang,
revision=self.metadata_dict.get("revision", None),
name=lang,
**self.metadata_dict["dataset"],
)
self.dataset_transform()
self.data_loaded = True