Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add CUREv1 retrieval dataset #1459

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mteb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
MTEB_ENG_CLASSIC,
MTEB_MAIN_RU,
MTEB_RETRIEVAL_LAW,
MTEB_RETRIEVAL_MEDICAL,
MTEB_RETRIEVAL_WITH_INSTRUCTIONS,
CoIR,
)
Expand All @@ -24,6 +25,7 @@
"MTEB_ENG_CLASSIC",
"MTEB_MAIN_RU",
"MTEB_RETRIEVAL_LAW",
"MTEB_RETRIEVAL_MEDICAL",
"MTEB_RETRIEVAL_WITH_INSTRUCTIONS",
"CoIR",
"TASKS_REGISTRY",
Expand Down
24 changes: 24 additions & 0 deletions mteb/benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,29 @@ def load_results(
citation=None,
)

# Curated medical-domain benchmark: retrieval, clustering, and reranking tasks
# across English, Polish, Chinese, and the cross-lingual CUREv1 settings.
# NOTE(review): `reference` is empty — consider linking the CUREv1 dataset card
# or an accompanying paper before release.
MTEB_RETRIEVAL_MEDICAL = Benchmark(
    name="MTEB(Medical)",
    tasks=get_tasks(
        tasks=[
            "CUREv1",
            "NFCorpus",
            "TRECCOVID",
            "TRECCOVID-PL",
            "SciFact",
            "SciFact-PL",
            "MedicalQARetrieval",
            "PublicHealthQA",
            "MedrxivClusteringP2P.v2",
            "MedrxivClusteringS2S.v2",
            "CmedqaRetrieval",
            "CMedQAv2-reranking",
        ],
    ),
    description="A curated set of MTEB tasks designed to evaluate systems in the context of medical information retrieval.",
    reference="",
    citation=None,
)

MTEB_MINERS_BITEXT_MINING = Benchmark(
name="MINERSBitextMining",
tasks=get_tasks(
Expand Down Expand Up @@ -702,6 +725,7 @@ def load_results(
"SpartQA",
"TempReasonL1",
"TRECCOVID",
"CUREv1",
"WinoGrande",
"BelebeleRetrieval",
"MLQARetrieval",
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Reranking/zho/CMTEBReranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class CMedQAv2(AbsTaskReranking):
main_score="map",
date=None,
form=None,
domains=None,
domains=["Medical", "Written"],
task_subtypes=None,
license=None,
annotations_creators=None,
Expand Down
1 change: 1 addition & 0 deletions mteb/tasks/Retrieval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
from .multilingual.BelebeleRetrieval import *
from .multilingual.CrossLingualSemanticDiscriminationWMT19 import *
from .multilingual.CrossLingualSemanticDiscriminationWMT21 import *
from .multilingual.CUREv1Retrieval import *
from .multilingual.IndicQARetrieval import *
from .multilingual.MintakaRetrieval import *
from .multilingual.MIRACLRetrieval import *
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Retrieval/eng/NFCorpusRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class NFCorpus(AbsTaskRetrieval):
eval_langs=["eng-Latn"],
main_score="ndcg_at_10",
date=None,
domains=None,
domains=["Medical", "Academic", "Written"],
task_subtypes=None,
license=None,
annotations_creators=None,
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Retrieval/eng/SciFactRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class SciFact(AbsTaskRetrieval):
eval_langs=["eng-Latn"],
main_score="ndcg_at_10",
date=None,
domains=None,
domains=["Academic", "Medical", "Written"],
task_subtypes=None,
license=None,
annotations_creators=None,
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Retrieval/eng/TRECCOVIDRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class TRECCOVID(AbsTaskRetrieval):
eval_langs=["eng-Latn"],
main_score="ndcg_at_10",
date=None,
domains=None,
domains=["Medical", "Academic", "Written"],
task_subtypes=None,
license=None,
annotations_creators=None,
Expand Down
151 changes: 151 additions & 0 deletions mteb/tasks/Retrieval/multilingual/CUREv1Retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from __future__ import annotations

from enum import Enum

from datasets import DatasetDict, load_dataset

from mteb.abstasks.TaskMetadata import TaskMetadata

from ....abstasks.AbsTaskRetrieval import AbsTaskRetrieval
from ....abstasks.MultilingualTask import MultilingualTask

# Dataset language code -> [query language, corpus language] as ISO 639-3 +
# script pairs. Queries vary by language while the second entry stays
# "eng-Latn" in every setting, matching load_data() below, which loads a
# shared corpus and per-language queries.
# NOTE(review): presumably the corpus is English-only — confirm against the
# CUREv1 dataset card.
_LANGUAGES = {
    "en": ["eng-Latn", "eng-Latn"],
    "es": ["spa-Latn", "eng-Latn"],
    "fr": ["fra-Latn", "eng-Latn"],
}


class CUREv1Splits(str, Enum):
    """Evaluation splits of CUREv1: one per medical discipline, plus ``all``.

    Member names serve as HuggingFace split identifiers, while the values
    carry the human-readable discipline labels.
    """

    all = "All"
    dentistry_and_oral_health = "Dentistry and Oral Health"
    dermatology = "Dermatology"
    gastroenterology = "Gastroenterology"
    genetics = "Genetics"
    neuroscience_and_neurology = "Neuroscience and Neurology"
    orthopedic_surgery = "Orthopedic Surgery"
    otorhinolaryngology = "Otorhinolaryngology"
    plastic_surgery = "Plastic Surgery"
    psychiatry_and_psychology = "Psychiatry and Psychology"
    pulmonology = "Pulmonology"

    @classmethod
    def names(cls) -> list[str]:
        """Return every member name, sorted alphabetically."""
        return sorted(member.name for member in cls)


class CUREv1Retrieval(MultilingualTask, AbsTaskRetrieval):
    """Cross-lingual medical retrieval over the CUREv1 dataset.

    Query-passage pairs curated by medical professionals across 10
    disciplines and 3 cross-lingual settings. The corpus and the relevance
    judgements are shared across languages — only the queries differ — so
    corpus/qrels are loaded once per split and reused for every language.
    """

    metadata = TaskMetadata(
        dataset={
            "path": "clinia/CUREv1",
            "revision": "3bcf51c91e04d04a8a3329dfbe988b964c5cbe83",
        },
        name="CUREv1",
        description="Collection of query-passage pairs curated by medical professionals, across 10 disciplines and 3 cross-lingual settings.",
        type="Retrieval",
        modalities=["text"],
        category="s2p",
        reference="https://huggingface.co/datasets/clinia/CUREv1",
        eval_splits=CUREv1Splits.names(),
        eval_langs=_LANGUAGES,
        main_score="ndcg_at_10",
        date=("2024-01-01", "2024-10-31"),
        domains=["Medical", "Academic", "Written"],
        task_subtypes=[],
        license="cc-by-nc-4.0",
        annotations_creators="expert-annotated",
        dialect=[],
        sample_creation="created",
        bibtex_citation="",
        prompt={
            "query": "Given a question by a medical professional, retrieve relevant passages that best answer the question",
        },
    )

    def _load_split(self, name: str, split: str, cache_dir: str | None = None):
        """Load one HF configuration (``corpus``, ``qrels`` or ``queries-<lang>``) of ``split``.

        Centralizes the dataset path/revision so the three loaders below
        cannot drift apart.
        """
        return load_dataset(
            path=self.metadata_dict["dataset"]["path"],
            revision=self.metadata_dict["dataset"]["revision"],
            name=name,
            split=split,
            cache_dir=cache_dir,
        )

    def _load_corpus(self, split: str, cache_dir: str | None = None):
        """Return the corpus as ``{doc_id: {"title": ..., "text": ...}}``."""
        ds = self._load_split("corpus", split, cache_dir)
        return {doc["_id"]: {"title": doc["title"], "text": doc["text"]} for doc in ds}

    def _load_qrels(self, split: str, cache_dir: str | None = None):
        """Return relevance judgements as ``{query_id: {doc_id: score}}``."""
        ds = self._load_split("qrels", split, cache_dir)
        qrels: dict[str, dict[str, int]] = {}
        for qrel in ds:
            qrels.setdefault(qrel["query-id"], {})[qrel["corpus-id"]] = int(
                qrel["score"]
            )
        return qrels

    def _load_queries(self, split: str, language: str, cache_dir: str | None = None):
        """Return queries for ``language`` as ``{query_id: text}``."""
        ds = self._load_split(f"queries-{language}", split, cache_dir)
        return {query["_id"]: query["text"] for query in ds}

    def load_data(self, **kwargs):
        """Populate ``self.corpus``, ``self.queries`` and ``self.relevant_docs``.

        Accepts optional ``eval_splits``, ``eval_langs`` and ``cache_dir``
        keyword overrides; defaults come from the task metadata. Idempotent:
        returns immediately once data is loaded.
        """
        if self.data_loaded:
            return

        eval_splits = kwargs.get("eval_splits", self.metadata.eval_splits)
        languages = kwargs.get("eval_langs", self.metadata.eval_langs)
        cache_dir = kwargs.get("cache_dir")

        # One nested {language: {split: data}} mapping per artifact.
        corpus = {language: {} for language in languages}
        queries = {language: {} for language in languages}
        relevant_docs = {language: {} for language in languages}

        for split in eval_splits:
            # Cross-lingual dataset: corpus and qrels are language-independent,
            # so load them once per split and share across languages.
            split_corpus = self._load_corpus(split=split, cache_dir=cache_dir)
            split_qrels = self._load_qrels(split=split, cache_dir=cache_dir)

            for language in languages:
                corpus[language][split] = split_corpus
                relevant_docs[language][split] = split_qrels
                # Only the queries depend on the language.
                queries[language][split] = self._load_queries(
                    split=split, language=language, cache_dir=cache_dir
                )

        self.corpus = DatasetDict(corpus)
        self.queries = DatasetDict(queries)
        self.relevant_docs = DatasetDict(relevant_docs)

        self.data_loaded = True
2 changes: 1 addition & 1 deletion mteb/tasks/Retrieval/pol/SciFactPLRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class SciFactPL(AbsTaskRetrieval):
eval_langs=["pol-Latn"],
main_score="ndcg_at_10",
date=None,
domains=None,
domains=["Academic", "Medical", "Written"],
task_subtypes=None,
license=None,
annotations_creators=None,
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Retrieval/pol/TRECCOVIDPLRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class TRECCOVIDPL(AbsTaskRetrieval):
"2019-12-01",
"2022-12-31",
), # approximate date of covid pandemic start and end (best guess)
domains=["Academic", "Non-fiction", "Written"],
domains=["Academic", "Medical", "Non-fiction", "Written"],
task_subtypes=["Article retrieval"],
license="not specified",
annotations_creators="derived",
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Retrieval/zho/CMTEBRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ class CmedqaRetrieval(AbsTaskRetrieval):
eval_langs=["cmn-Hans"],
main_score="ndcg_at_10",
date=None,
domains=None,
domains=["Medical", "Written"],
task_subtypes=None,
license=None,
annotations_creators=None,
Expand Down
Loading