Skip to content

Commit

Permalink
feat: add CUREv1 retrieval dataset (#1459)
Browse files Browse the repository at this point in the history
* feat: add CUREv1 dataset

---------

Co-authored-by: nadshe <[email protected]>
Co-authored-by: olivierr42 <[email protected]>
Co-authored-by: Daniel Buades Marcos <[email protected]>

* feat: add missing domains to medical tasks

* feat: modify benchmark tasks

* chore: benchmark naming

---------

Co-authored-by: nadshe <[email protected]>
Co-authored-by: olivierr42 <[email protected]>
  • Loading branch information
3 people authored Nov 21, 2024
1 parent 7186e04 commit 1cc6c9e
Show file tree
Hide file tree
Showing 11 changed files with 185 additions and 7 deletions.
2 changes: 2 additions & 0 deletions mteb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
MTEB_ENG_CLASSIC,
MTEB_MAIN_RU,
MTEB_RETRIEVAL_LAW,
MTEB_RETRIEVAL_MEDICAL,
MTEB_RETRIEVAL_WITH_INSTRUCTIONS,
CoIR,
)
Expand All @@ -24,6 +25,7 @@
"MTEB_ENG_CLASSIC",
"MTEB_MAIN_RU",
"MTEB_RETRIEVAL_LAW",
"MTEB_RETRIEVAL_MEDICAL",
"MTEB_RETRIEVAL_WITH_INSTRUCTIONS",
"CoIR",
"TASKS_REGISTRY",
Expand Down
24 changes: 24 additions & 0 deletions mteb/benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,29 @@ def load_results(
citation=None,
)

MTEB_RETRIEVAL_MEDICAL = Benchmark(
name="MTEB(Medical)",
tasks=get_tasks(
tasks=[
"CUREv1",
"NFCorpus",
"TRECCOVID",
"TRECCOVID-PL",
"SciFact",
"SciFact-PL",
"MedicalQARetrieval",
"PublicHealthQA",
"MedrxivClusteringP2P.v2",
"MedrxivClusteringS2S.v2",
"CmedqaRetrieval",
"CMedQAv2-reranking",
],
),
description="A curated set of MTEB tasks designed to evaluate systems in the context of medical information retrieval.",
reference="",
citation=None,
)

MTEB_MINERS_BITEXT_MINING = Benchmark(
name="MINERSBitextMining",
tasks=get_tasks(
Expand Down Expand Up @@ -702,6 +725,7 @@ def load_results(
"SpartQA",
"TempReasonL1",
"TRECCOVID",
"CUREv1",
"WinoGrande",
"BelebeleRetrieval",
"MLQARetrieval",
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Reranking/zho/CMTEBReranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class CMedQAv2(AbsTaskReranking):
main_score="map",
date=None,
form=None,
domains=None,
domains=["Medical", "Written"],
task_subtypes=None,
license=None,
annotations_creators=None,
Expand Down
1 change: 1 addition & 0 deletions mteb/tasks/Retrieval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
from .multilingual.BelebeleRetrieval import *
from .multilingual.CrossLingualSemanticDiscriminationWMT19 import *
from .multilingual.CrossLingualSemanticDiscriminationWMT21 import *
from .multilingual.CUREv1Retrieval import *
from .multilingual.IndicQARetrieval import *
from .multilingual.MintakaRetrieval import *
from .multilingual.MIRACLRetrieval import *
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Retrieval/eng/NFCorpusRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class NFCorpus(AbsTaskRetrieval):
eval_langs=["eng-Latn"],
main_score="ndcg_at_10",
date=None,
domains=None,
domains=["Medical", "Academic", "Written"],
task_subtypes=None,
license=None,
annotations_creators=None,
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Retrieval/eng/SciFactRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class SciFact(AbsTaskRetrieval):
eval_langs=["eng-Latn"],
main_score="ndcg_at_10",
date=None,
domains=None,
domains=["Academic", "Medical", "Written"],
task_subtypes=None,
license=None,
annotations_creators=None,
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Retrieval/eng/TRECCOVIDRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class TRECCOVID(AbsTaskRetrieval):
eval_langs=["eng-Latn"],
main_score="ndcg_at_10",
date=None,
domains=None,
domains=["Medical", "Academic", "Written"],
task_subtypes=None,
license=None,
annotations_creators=None,
Expand Down
151 changes: 151 additions & 0 deletions mteb/tasks/Retrieval/multilingual/CUREv1Retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from __future__ import annotations

from enum import Enum

from datasets import DatasetDict, load_dataset

from mteb.abstasks.TaskMetadata import TaskMetadata

from ....abstasks.AbsTaskRetrieval import AbsTaskRetrieval
from ....abstasks.MultilingualTask import MultilingualTask

_LANGUAGES = {
"en": ["eng-Latn", "eng-Latn"],
"es": ["spa-Latn", "eng-Latn"],
"fr": ["fra-Latn", "eng-Latn"],
}


class CUREv1Splits(str, Enum):
all = "All"
dentistry_and_oral_health = "Dentistry and Oral Health"
dermatology = "Dermatology"
gastroenterology = "Gastroenterology"
genetics = "Genetics"
neuroscience_and_neurology = "Neuroscience and Neurology"
orthopedic_surgery = "Orthopedic Surgery"
otorhinolaryngology = "Otorhinolaryngology"
plastic_surgery = "Plastic Surgery"
psychiatry_and_psychology = "Psychiatry and Psychology"
pulmonology = "Pulmonology"

@classmethod
def names(cls) -> list[str]:
return sorted(cls._member_names_)


class CUREv1Retrieval(MultilingualTask, AbsTaskRetrieval):
metadata = TaskMetadata(
dataset={
"path": "clinia/CUREv1",
"revision": "3bcf51c91e04d04a8a3329dfbe988b964c5cbe83",
},
name="CUREv1",
description="Collection of query-passage pairs curated by medical professionals, across 10 disciplines and 3 cross-lingual settings.",
type="Retrieval",
modalities=["text"],
category="s2p",
reference="https://huggingface.co/datasets/clinia/CUREv1",
eval_splits=CUREv1Splits.names(),
eval_langs=_LANGUAGES,
main_score="ndcg_at_10",
date=("2024-01-01", "2024-10-31"),
domains=["Medical", "Academic", "Written"],
task_subtypes=[],
license="cc-by-nc-4.0",
annotations_creators="expert-annotated",
dialect=[],
sample_creation="created",
bibtex_citation="",
prompt={
"query": "Given a question by a medical professional, retrieve relevant passages that best answer the question",
},
)

def _load_corpus(self, split: str, cache_dir: str | None = None):
ds = load_dataset(
path=self.metadata_dict["dataset"]["path"],
revision=self.metadata_dict["dataset"]["revision"],
name="corpus",
split=split,
cache_dir=cache_dir,
)

corpus = {
doc["_id"]: {"title": doc["title"], "text": doc["text"]} for doc in ds
}

return corpus

def _load_qrels(self, split: str, cache_dir: str | None = None):
ds = load_dataset(
path=self.metadata_dict["dataset"]["path"],
revision=self.metadata_dict["dataset"]["revision"],
name="qrels",
split=split,
cache_dir=cache_dir,
)

qrels = {}

for qrel in ds:
query_id = qrel["query-id"]
doc_id = qrel["corpus-id"]
score = int(qrel["score"])
if query_id not in qrels:
qrels[query_id] = {}
qrels[query_id][doc_id] = score

return qrels

def _load_queries(self, split: str, language: str, cache_dir: str | None = None):
ds = load_dataset(
path=self.metadata_dict["dataset"]["path"],
revision=self.metadata_dict["dataset"]["revision"],
name=f"queries-{language}",
split=split,
cache_dir=cache_dir,
)

queries = {query["_id"]: query["text"] for query in ds}

return queries

def load_data(self, **kwargs):
if self.data_loaded:
return

eval_splits = kwargs.get("eval_splits", self.metadata.eval_splits)
languages = kwargs.get("eval_langs", self.metadata.eval_langs)
cache_dir = kwargs.get("cache_dir", None)

# Iterate over splits and languages
corpus = {
language: {split: None for split in eval_splits} for language in languages
}
queries = {
language: {split: None for split in eval_splits} for language in languages
}
relevant_docs = {
language: {split: None for split in eval_splits} for language in languages
}
for split in eval_splits:
# Since this is a cross-lingual dataset, the corpus and the relevant documents do not depend on the language
split_corpus = self._load_corpus(split=split, cache_dir=cache_dir)
split_qrels = self._load_qrels(split=split, cache_dir=cache_dir)

# Queries depend on the language
for language in languages:
corpus[language][split] = split_corpus
relevant_docs[language][split] = split_qrels

queries[language][split] = self._load_queries(
split=split, language=language, cache_dir=cache_dir
)

# Convert into DatasetDict
self.corpus = DatasetDict(corpus)
self.queries = DatasetDict(queries)
self.relevant_docs = DatasetDict(relevant_docs)

self.data_loaded = True
2 changes: 1 addition & 1 deletion mteb/tasks/Retrieval/pol/SciFactPLRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class SciFactPL(AbsTaskRetrieval):
eval_langs=["pol-Latn"],
main_score="ndcg_at_10",
date=None,
domains=None,
domains=["Academic", "Medical", "Written"],
task_subtypes=None,
license=None,
annotations_creators=None,
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Retrieval/pol/TRECCOVIDPLRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class TRECCOVIDPL(AbsTaskRetrieval):
"2019-12-01",
"2022-12-31",
), # approximate date of covid pandemic start and end (best guess)
domains=["Academic", "Non-fiction", "Written"],
domains=["Academic", "Medical", "Non-fiction", "Written"],
task_subtypes=["Article retrieval"],
license="not specified",
annotations_creators="derived",
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Retrieval/zho/CMTEBRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ class CmedqaRetrieval(AbsTaskRetrieval):
eval_langs=["cmn-Hans"],
main_score="ndcg_at_10",
date=None,
domains=None,
domains=["Medical", "Written"],
task_subtypes=None,
license=None,
annotations_creators=None,
Expand Down

0 comments on commit 1cc6c9e

Please sign in to comment.