diff --git a/mteb-zh/readme.md b/mteb-zh/readme.md index 4882092..364ff43 100644 --- a/mteb-zh/readme.md +++ b/mteb-zh/readme.md @@ -59,38 +59,29 @@ pip install -r requirements.txt ``` 2. 运行评测脚本 ```bash -# model_type: m3e_base, erlangshen, uer, d_meta_soul, openai, text2vec ... -python run_mteb_zh +python run_mteb_zh.py --model-type <model_type> --model-id <model_id> ``` 3. 查看帮助 ```bash -python run_mteb_zh --help +python run_mteb_zh.py --help ``` ### 示例 评测 M3E-base 模型 ```bash -python run_mteb_zh sentence-transformer moka-ai/m3e-base +python run_mteb_zh.py --model-type sentence_transformer --model-id moka-ai/m3e-base ``` 评测 UER 模型 ```bash -python run_mteb_zh uer uer/sbert-base-chinese-nli +python run_mteb_zh.py --model-type sentence_transformer --model-id uer/sbert-base-chinese-nli ``` 评测 ErLangShen 模型 ```bash -python run_mteb_zh erlangshen +python run_mteb_zh.py --model-type erlangshen ``` -case ModelType.m3e_small: - return SentenceTransformer('moka-ai/m3e-small') -case ModelType.m3e_base: - return SentenceTransformer('moka-ai/m3e-base') -case ModelType.d_meta_soul: - return SentenceTransformer('DMetaSoul/sbert-chinese-general-v2') -case ModelType.uer: - return SentenceTransformer('uer/sbert-base-chinese-nli') ## 评测你的模型 diff --git a/mteb-zh/run_mteb_zh.py b/mteb-zh/run_mteb_zh.py index ff971ec..33fd34b 100644 --- a/mteb-zh/run_mteb_zh.py +++ b/mteb-zh/run_mteb_zh.py @@ -30,24 +30,34 @@ ] +def filter_by_name(name: str): + return [task for task in default_tasks if task.description['name'] == name] # type: ignore + + +def filter_by_type(task_type: TaskType): + if task_type is TaskType.All: + return default_tasks + else: + return [task for task in default_tasks if task.description['type'] == task_type.value] # type: ignore + + def main( model_type: Annotated[ModelType, typer.Option()], model_id: str | None = None, - model_name: str | None = None, task_type: TaskType = TaskType.Classification, + task_name: str | None = None, output_folder: Path = Path('results'), ): output_folder = Path(output_folder) 
model = load_model(model_type, model_id) - if task_type is TaskType.All: - tasks = default_tasks + if task_name: + tasks = filter_by_name(task_name) else: - tasks = [task for task in default_tasks if task.description['type'] == task_type.value] # type: ignore + tasks = filter_by_type(task_type) evaluation = MTEB(tasks=tasks) - if model_name is None: - model_name = model_type.value + (f'-{model_id.replace("/", "-")}' if model_id else '') + model_name = model_type.value + (f'-{model_id.replace("/", "-")}' if model_id else '') evaluation.run(model, output_folder=str(output_folder / model_name))