diff --git a/mteb/__init__.py b/mteb/__init__.py
index 1ef561a5f..6de017b1f 100644
--- a/mteb/__init__.py
+++ b/mteb/__init__.py
@@ -6,6 +6,7 @@
     MTEB_ENG_CLASSIC,
     MTEB_MAIN_RU,
     MTEB_RETRIEVAL_LAW,
+    MTEB_RETRIEVAL_MEDICAL,
     MTEB_RETRIEVAL_WITH_INSTRUCTIONS,
     CoIR,
 )
@@ -24,6 +25,7 @@
     "MTEB_ENG_CLASSIC",
     "MTEB_MAIN_RU",
     "MTEB_RETRIEVAL_LAW",
+    "MTEB_RETRIEVAL_MEDICAL",
     "MTEB_RETRIEVAL_WITH_INSTRUCTIONS",
     "CoIR",
     "TASKS_REGISTRY",
diff --git a/mteb/benchmarks/benchmarks.py b/mteb/benchmarks/benchmarks.py
index 4b5c53c2c..9aaefda3c 100644
--- a/mteb/benchmarks/benchmarks.py
+++ b/mteb/benchmarks/benchmarks.py
@@ -308,6 +308,29 @@ def load_results(
     citation=None,
 )
 
+MTEB_RETRIEVAL_MEDICAL = Benchmark(
+    name="MTEB(Medical)",
+    tasks=get_tasks(
+        tasks=[
+            "CUREv1",
+            "NFCorpus",
+            "TRECCOVID",
+            "TRECCOVID-PL",
+            "SciFact",
+            "SciFact-PL",
+            "MedicalQARetrieval",
+            "PublicHealthQA",
+            "MedrxivClusteringP2P.v2",
+            "MedrxivClusteringS2S.v2",
+            "CmedqaRetrieval",
+            "CMedQAv2-reranking",
+        ],
+    ),
+    description="A curated set of MTEB tasks designed to evaluate systems in the context of medical information retrieval.",
+    reference="",
+    citation=None,
+)
+
 MTEB_MINERS_BITEXT_MINING = Benchmark(
     name="MINERSBitextMining",
     tasks=get_tasks(
@@ -702,6 +725,7 @@ def load_results(
             "SpartQA",
             "TempReasonL1",
             "TRECCOVID",
+            "CUREv1",
             "WinoGrande",
             "BelebeleRetrieval",
             "MLQARetrieval",
diff --git a/mteb/tasks/Reranking/zho/CMTEBReranking.py b/mteb/tasks/Reranking/zho/CMTEBReranking.py
index 302f62adf..7a33f7ae0 100644
--- a/mteb/tasks/Reranking/zho/CMTEBReranking.py
+++ b/mteb/tasks/Reranking/zho/CMTEBReranking.py
@@ -128,7 +128,7 @@ class CMedQAv2(AbsTaskReranking):
         main_score="map",
         date=None,
         form=None,
-        domains=None,
+        domains=["Medical", "Written"],
         task_subtypes=None,
         license=None,
         annotations_creators=None,
diff --git a/mteb/tasks/Retrieval/__init__.py b/mteb/tasks/Retrieval/__init__.py
index f8a47b08a..ca41d4354 100644
--- a/mteb/tasks/Retrieval/__init__.py
+++ b/mteb/tasks/Retrieval/__init__.py
@@ -105,6 +105,7 @@
 from .multilingual.BelebeleRetrieval import *
 from .multilingual.CrossLingualSemanticDiscriminationWMT19 import *
 from .multilingual.CrossLingualSemanticDiscriminationWMT21 import *
+from .multilingual.CUREv1Retrieval import *
 from .multilingual.IndicQARetrieval import *
 from .multilingual.MintakaRetrieval import *
 from .multilingual.MIRACLRetrieval import *
diff --git a/mteb/tasks/Retrieval/eng/NFCorpusRetrieval.py b/mteb/tasks/Retrieval/eng/NFCorpusRetrieval.py
index 7c40b6707..31f4eb60b 100644
--- a/mteb/tasks/Retrieval/eng/NFCorpusRetrieval.py
+++ b/mteb/tasks/Retrieval/eng/NFCorpusRetrieval.py
@@ -21,7 +21,7 @@ class NFCorpus(AbsTaskRetrieval):
         eval_langs=["eng-Latn"],
         main_score="ndcg_at_10",
         date=None,
-        domains=None,
+        domains=["Medical", "Academic", "Written"],
         task_subtypes=None,
         license=None,
         annotations_creators=None,
diff --git a/mteb/tasks/Retrieval/eng/SciFactRetrieval.py b/mteb/tasks/Retrieval/eng/SciFactRetrieval.py
index 05e9a6e54..1dc47d8b6 100644
--- a/mteb/tasks/Retrieval/eng/SciFactRetrieval.py
+++ b/mteb/tasks/Retrieval/eng/SciFactRetrieval.py
@@ -21,7 +21,7 @@ class SciFact(AbsTaskRetrieval):
         eval_langs=["eng-Latn"],
         main_score="ndcg_at_10",
         date=None,
-        domains=None,
+        domains=["Academic", "Medical", "Written"],
         task_subtypes=None,
         license=None,
         annotations_creators=None,
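For context while reviewing: once the registration above lands, the benchmark should be selectable by its `name` field. A minimal sketch, assuming the existing `mteb.get_benchmark` and `mteb.MTEB` entry points (the model choice is purely illustrative):

```python
import mteb
from sentence_transformers import SentenceTransformer

# "MTEB(Medical)" matches the `name` given to MTEB_RETRIEVAL_MEDICAL above.
benchmark = mteb.get_benchmark("MTEB(Medical)")
evaluation = mteb.MTEB(tasks=benchmark)

# Any mteb-compatible encoder works; this model is only an example.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
evaluation.run(model, output_folder="results")
```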
diff --git a/mteb/tasks/Retrieval/eng/TRECCOVIDRetrieval.py b/mteb/tasks/Retrieval/eng/TRECCOVIDRetrieval.py
index 6c7b7f01d..00c96c0d0 100644
--- a/mteb/tasks/Retrieval/eng/TRECCOVIDRetrieval.py
+++ b/mteb/tasks/Retrieval/eng/TRECCOVIDRetrieval.py
@@ -21,7 +21,7 @@ class TRECCOVID(AbsTaskRetrieval):
         eval_langs=["eng-Latn"],
         main_score="ndcg_at_10",
         date=None,
-        domains=None,
+        domains=["Medical", "Academic", "Written"],
         task_subtypes=None,
         license=None,
         annotations_creators=None,
diff --git a/mteb/tasks/Retrieval/multilingual/CUREv1Retrieval.py b/mteb/tasks/Retrieval/multilingual/CUREv1Retrieval.py
new file mode 100644
index 000000000..6e97786a7
--- /dev/null
+++ b/mteb/tasks/Retrieval/multilingual/CUREv1Retrieval.py
@@ -0,0 +1,151 @@
+from __future__ import annotations
+
+from enum import Enum
+
+from datasets import DatasetDict, load_dataset
+
+from mteb.abstasks.TaskMetadata import TaskMetadata
+
+from ....abstasks.AbsTaskRetrieval import AbsTaskRetrieval
+from ....abstasks.MultilingualTask import MultilingualTask
+
+_LANGUAGES = {
+    "en": ["eng-Latn", "eng-Latn"],
+    "es": ["spa-Latn", "eng-Latn"],
+    "fr": ["fra-Latn", "eng-Latn"],
+}
+
+
+class CUREv1Splits(str, Enum):
+    all = "All"
+    dentistry_and_oral_health = "Dentistry and Oral Health"
+    dermatology = "Dermatology"
+    gastroenterology = "Gastroenterology"
+    genetics = "Genetics"
+    neuroscience_and_neurology = "Neuroscience and Neurology"
+    orthopedic_surgery = "Orthopedic Surgery"
+    otorhinolaryngology = "Otorhinolaryngology"
+    plastic_surgery = "Plastic Surgery"
+    psychiatry_and_psychology = "Psychiatry and Psychology"
+    pulmonology = "Pulmonology"
+
+    @classmethod
+    def names(cls) -> list[str]:
+        return sorted(cls._member_names_)
+
+
+class CUREv1Retrieval(MultilingualTask, AbsTaskRetrieval):
+    metadata = TaskMetadata(
+        dataset={
+            "path": "clinia/CUREv1",
+            "revision": "3bcf51c91e04d04a8a3329dfbe988b964c5cbe83",
+        },
+        name="CUREv1",
+        description="Collection of query-passage pairs curated by medical professionals, across 10 disciplines and 3 cross-lingual settings.",
+        type="Retrieval",
+        modalities=["text"],
+        category="s2p",
+        reference="https://huggingface.co/datasets/clinia/CUREv1",
+        eval_splits=CUREv1Splits.names(),
+        eval_langs=_LANGUAGES,
+        main_score="ndcg_at_10",
+        date=("2024-01-01", "2024-10-31"),
+        domains=["Medical", "Academic", "Written"],
+        task_subtypes=[],
+        license="cc-by-nc-4.0",
+        annotations_creators="expert-annotated",
+        dialect=[],
+        sample_creation="created",
+        bibtex_citation="",
+        prompt={
+            "query": "Given a question by a medical professional, retrieve relevant passages that best answer the question",
+        },
+    )
+
+    def _load_corpus(self, split: str, cache_dir: str | None = None):
+        ds = load_dataset(
+            path=self.metadata_dict["dataset"]["path"],
+            revision=self.metadata_dict["dataset"]["revision"],
+            name="corpus",
+            split=split,
+            cache_dir=cache_dir,
+        )
+
+        corpus = {
+            doc["_id"]: {"title": doc["title"], "text": doc["text"]} for doc in ds
+        }
+
+        return corpus
+
+    def _load_qrels(self, split: str, cache_dir: str | None = None):
+        ds = load_dataset(
+            path=self.metadata_dict["dataset"]["path"],
+            revision=self.metadata_dict["dataset"]["revision"],
+            name="qrels",
+            split=split,
+            cache_dir=cache_dir,
+        )
+
+        qrels = {}
+
+        for qrel in ds:
+            query_id = qrel["query-id"]
+            doc_id = qrel["corpus-id"]
+            score = int(qrel["score"])
+            if query_id not in qrels:
+                qrels[query_id] = {}
+            qrels[query_id][doc_id] = score
+
+        return qrels
+
+    def _load_queries(self, split: str, language: str, cache_dir: str | None = None):
+        ds = load_dataset(
+            path=self.metadata_dict["dataset"]["path"],
+            revision=self.metadata_dict["dataset"]["revision"],
+            name=f"queries-{language}",
+            split=split,
+            cache_dir=cache_dir,
+        )
+
+        queries = {query["_id"]: query["text"] for query in ds}
+
+        return queries
+
+    def load_data(self, **kwargs):
+        if self.data_loaded:
+            return
+
+        eval_splits = kwargs.get("eval_splits", self.metadata.eval_splits)
+        languages = kwargs.get("eval_langs", self.metadata.eval_langs)
+        cache_dir = kwargs.get("cache_dir", None)
+
+        # Iterate over splits and languages
+        corpus = {
+            language: {split: None for split in eval_splits} for language in languages
+        }
+        queries = {
+            language: {split: None for split in eval_splits} for language in languages
+        }
+        relevant_docs = {
+            language: {split: None for split in eval_splits} for language in languages
+        }
+        for split in eval_splits:
+            # Since this is a cross-lingual dataset, the corpus and the relevant documents do not depend on the language
+            split_corpus = self._load_corpus(split=split, cache_dir=cache_dir)
+            split_qrels = self._load_qrels(split=split, cache_dir=cache_dir)
+
+            # Queries depend on the language
+            for language in languages:
+                corpus[language][split] = split_corpus
+                relevant_docs[language][split] = split_qrels
+
+                queries[language][split] = self._load_queries(
+                    split=split, language=language, cache_dir=cache_dir
+                )
+
+        # Convert into DatasetDict
+        self.corpus = DatasetDict(corpus)
+        self.queries = DatasetDict(queries)
+        self.relevant_docs = DatasetDict(relevant_docs)
+
+        self.data_loaded = True
diff --git a/mteb/tasks/Retrieval/pol/SciFactPLRetrieval.py b/mteb/tasks/Retrieval/pol/SciFactPLRetrieval.py
index 2588b1c28..92d61b42b 100644
--- a/mteb/tasks/Retrieval/pol/SciFactPLRetrieval.py
+++ b/mteb/tasks/Retrieval/pol/SciFactPLRetrieval.py
@@ -22,7 +22,7 @@ class SciFactPL(AbsTaskRetrieval):
         eval_langs=["pol-Latn"],
         main_score="ndcg_at_10",
         date=None,
-        domains=None,
+        domains=["Academic", "Medical", "Written"],
         task_subtypes=None,
         license=None,
         annotations_creators=None,
diff --git a/mteb/tasks/Retrieval/pol/TRECCOVIDPLRetrieval.py b/mteb/tasks/Retrieval/pol/TRECCOVIDPLRetrieval.py
index 4ba6a9ac0..f9f331191 100644
--- a/mteb/tasks/Retrieval/pol/TRECCOVIDPLRetrieval.py
+++ b/mteb/tasks/Retrieval/pol/TRECCOVIDPLRetrieval.py
@@ -25,7 +25,7 @@ class TRECCOVIDPL(AbsTaskRetrieval):
             "2019-12-01",
             "2022-12-31",
         ),  # approximate date of covid pandemic start and end (best guess)
-        domains=["Academic", "Non-fiction", "Written"],
+        domains=["Academic", "Medical", "Non-fiction", "Written"],
         task_subtypes=["Article retrieval"],
         license="not specified",
         annotations_creators="derived",
diff --git a/mteb/tasks/Retrieval/zho/CMTEBRetrieval.py b/mteb/tasks/Retrieval/zho/CMTEBRetrieval.py
index 08674ec8c..ad26652cc 100644
--- a/mteb/tasks/Retrieval/zho/CMTEBRetrieval.py
+++ b/mteb/tasks/Retrieval/zho/CMTEBRetrieval.py
@@ -236,7 +236,7 @@ class CmedqaRetrieval(AbsTaskRetrieval):
         eval_langs=["cmn-Hans"],
         main_score="ndcg_at_10",
         date=None,
-        domains=None,
+        domains=["Medical", "Written"],
         task_subtypes=None,
         license=None,
         annotations_creators=None,
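To exercise the new task directly, a run can be restricted to a single discipline, since CUREv1 exposes each medical discipline as its own eval split. A minimal sketch, assuming the standard `mteb.get_tasks` and `MTEB.run` flow (model name and output folder are illustrative):

```python
import mteb
from sentence_transformers import SentenceTransformer

# Select the new multilingual task by the `name` in its TaskMetadata.
tasks = mteb.get_tasks(tasks=["CUREv1"])
evaluation = mteb.MTEB(tasks=tasks)

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
evaluation.run(
    model,
    eval_splits=["dermatology"],  # any name from CUREv1Splits.names(); omit to run every split
    output_folder="results",
)
```

Because `load_data` treats the corpus and qrels as language-independent, each split's corpus and relevance judgments are loaded once and shared across the `en`, `es`, and `fr` query sets.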