support arenahard
bittersweet1999 committed Apr 26, 2024
1 parent 004ed79 commit 5f5b1e0
Showing 11 changed files with 540 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -70,6 +70,7 @@ Just like a compass guides us on our journey, OpenCompass will guide you through
## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>

- **\[2024.04.26\]** We supported the evaluation of [ArenaHard](configs/eval_subjective_arena_hard.py), welcome to try! 🔥🔥🔥
- **\[2024.04.22\]** We supported the evaluation of [LLaMA3](configs/models/hf_llama/hf_llama3_8b.py) and [LLaMA3-Instruct](configs/models/hf_llama/hf_llama3_8b_instruct.py), welcome to try! 🔥🔥🔥
- **\[2024.02.29\]** We supported MT-Bench, AlpacaEval and AlignBench; more information can be found [here](https://opencompass.readthedocs.io/en/latest/advanced_guides/subjective_evaluation.html).
- **\[2024.01.30\]** We released OpenCompass 2.0. Click [CompassKit](https://github.com/open-compass), [CompassHub](https://hub.opencompass.org.cn/home), and [CompassRank](https://rank.opencompass.org.cn/home) for more information!
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -69,6 +69,7 @@
## 🚀 What's New <a><img width="35" height="20" src="https://user-images.githubusercontent.com/12782558/212848161-5e783dd6-11e8-4fe0-bbba-39ffb77730be.png"></a>

- **\[2024.04.26\]** We supported the [ArenaHard evaluation](configs/eval_subjective_arena_hard.py), welcome to try! 🔥🔥🔥
- **\[2024.04.22\]** We supported the evaluation of [LLaMA3](configs/models/hf_llama/hf_llama3_8b.py) and [LLaMA3-Instruct](configs/models/hf_llama/hf_llama3_8b_instruct.py), welcome to try! 🔥🔥🔥
- **\[2024.02.29\]** We supported MT-Bench, AlpacaEval and AlignBench; more information can be found [here](https://opencompass.readthedocs.io/en/latest/advanced_guides/subjective_evaluation.html).
- **\[2024.01.30\]** We released OpenCompass 2.0. For more information, visit [CompassKit](https://github.com/open-compass), [CompassHub](https://hub.opencompass.org.cn/home), and [CompassRank](https://rank.opencompass.org.cn/home).
72 changes: 72 additions & 0 deletions configs/datasets/subjective/arena_hard/arena_hard_scoring.py
@@ -0,0 +1,72 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.datasets import ArenaHardDataset
from mmengine.config import read_base

subjective_reader_cfg = dict(
    input_columns=['question'],
    output_column='judge',
)

subjective_all_sets = [
    "question",
]


subjective_datasets = []

system_prompt = "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user prompt displayed below. You will be given assistant A's answer and assistant B's answer. Your job is to evaluate which assistant's answer is better.\n\nBegin your evaluation by generating your own answer to the prompt. You must provide your answers before judging any answers.\n\nWhen evaluating the assistants' answers, compare both assistants' answers with your answer. You must identify and correct any mistakes or inaccurate information.\n\nThen consider if the assistant's answers are helpful, relevant, and concise. Helpful means the answer correctly responds to the prompt or follows the instructions. Note when user prompt has any ambiguity or more than one interpretation, it is more helpful and appropriate to ask for clarifications or more information from the user than providing an answer based on assumptions. Relevant means all parts of the response closely connect or are appropriate to what is being asked. Concise means the response is clear and not verbose or excessive.\n\nThen consider the creativity and novelty of the assistant's answers when needed. Finally, identify any missing important information in the assistants' answers that would be beneficial to include when responding to the user prompt.\n\nAfter providing your explanation, you must output only one of the following choices as your final verdict with a label:\n\n1. Assistant A is significantly better: [[A>>B]]\n2. Assistant A is slightly better: [[A>B]]\n3. Tie, relatively the same: [[A=B]]\n4. Assistant B is slightly better: [[B>A]]\n5. Assistant B is significantly better: [[B>>A]]\n\nExample output: \"My final verdict is tie: [[A=B]]\"."

judge_prompt = "<|User Prompt|>\n{question}\n\n<|The Start of Assistant A's Answer|>\n{prediction}\n<|The End of Assistant A's Answer|>\n\n<|The Start of Assistant B's Answer|>\n{prediction2}\n<|The End of Assistant B's Answer|>"


# Build one dataset config per subset: a GenInferencer produces answers and
# the LMEvaluator acts as the pairwise judge using the prompts defined above.
for _name in subjective_all_sets:
    subjective_infer_cfg = dict(
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(round=[
                dict(
                    role='HUMAN',
                    prompt="{question}"
                ),
            ]),
        ),
        retriever=dict(type=ZeroRetriever),
        inferencer=dict(type=GenInferencer, max_out_len=4096),
    )

    subjective_eval_cfg = dict(
        evaluator=dict(
            type=LMEvaluator,
            prompt_template=dict(
                type=PromptTemplate,
                template=dict(
                    begin=[
                        dict(
                            role='SYSTEM',
                            fallback_role='HUMAN',
                            prompt=system_prompt)
                    ],
                    round=[
                        dict(
                            role='HUMAN',
                            prompt=judge_prompt
                        ),
                    ]),
            ),
        ),
        pred_role="BOT",
    )

    subjective_datasets.append(
        dict(
            abbr=f"{_name}",
            type=ArenaHardDataset,
            path="./data/subjective/arena_hard",
            name=_name,
            reader_cfg=subjective_reader_cfg,
            infer_cfg=subjective_infer_cfg,
            eval_cfg=subjective_eval_cfg
        ))
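As a quick illustration of how the placeholders above are meant to be filled (the `LMEvaluator` is expected to substitute `{question}` from the dataset and `{prediction}`/`{prediction2}` from the two compared models), here is a standalone sketch; the question and the two candidate answers are invented for demonstration only:

```python
# Standalone sketch (not part of the commit): filling the ArenaHard judge
# prompt. The question and the two candidate answers below are made up.
judge_prompt = ("<|User Prompt|>\n{question}\n\n"
                "<|The Start of Assistant A's Answer|>\n{prediction}\n"
                "<|The End of Assistant A's Answer|>\n\n"
                "<|The Start of Assistant B's Answer|>\n{prediction2}\n"
                "<|The End of Assistant B's Answer|>")

filled = judge_prompt.format(
    question='Write a Python one-liner that reverses a string.',
    prediction='s[::-1]',
    prediction2="''.join(reversed(s))",
)
print(filled)  # this text is what the judge model would see
```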
104 changes: 104 additions & 0 deletions configs/eval_subjective_arena_hard.py
@@ -0,0 +1,104 @@
from copy import deepcopy

from mmengine.config import read_base

from opencompass.models import (HuggingFace, HuggingFaceCausalLM,
                                HuggingFaceChatGLM3, OpenAI, TurboMindModel)
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import ArenaHardSummarizer

with read_base():
    from .datasets.subjective.arena_hard.arena_hard_scoring import subjective_datasets

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ]
)

_meta_template = dict(
    round=[
        dict(role="HUMAN", begin="<|begin_of_text|>user<|end_header_id|>\n\n", end="<|eot_id|>"),
        dict(role="BOT", begin="<|begin_of_text|>assistant<|end_header_id|>\n\n", end="<|eot_id|>", generate=True),
    ],
)

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr="llama-3-8b-instruct-hf",
        path="meta-llama/Meta-Llama-3-8B-Instruct",
        model_kwargs=dict(device_map="auto"),
        tokenizer_kwargs=dict(
            padding_side="left",
            truncation_side="left",
            use_fast=False,
        ),
        meta_template=_meta_template,
        max_out_len=4096,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
        generation_kwargs={"eos_token_id": [128001, 128009]},
        batch_padding=True,
    )
]

datasets = [*subjective_datasets]

work_dir = 'outputs/arena_hard/'
# ------------- Inference Stage ----------------------------------------


infer = dict(
    partitioner=dict(type=SizePartitioner, max_task_size=1000000),
    runner=dict(
        type=LocalRunner,
        max_num_workers=32,
        task=dict(type=OpenICLInferTask)),
)

judge_models = [dict(
    abbr='GPT4-Turbo',
    type=OpenAI,
    path='gpt-4-1106-preview',
    key='',
    meta_template=api_meta_template,
    query_per_second=1,
    max_out_len=1024,
    max_seq_len=4096,
    batch_size=10,
    retry=10,
    temperature=0,
)]

## ------------- Evaluation Configuration
gpt4_0314 = dict(
    abbr='gpt4-0314',
    type=OpenAI,
)

eval = dict(
    partitioner=dict(
        type=SubjectiveSizePartitioner,
        max_task_size=1000000,
        mode='m2n',
        infer_order='double',
        base_models=[gpt4_0314],
        compare_models=models,
        judge_models=judge_models,
    ),
    runner=dict(type=LocalRunner, max_num_workers=16, task=dict(type=SubjectiveEvalTask)),
    given_pred=[{'abbr': 'gpt4-0314', 'path': ''}]
)

summarizer = dict(
    type=ArenaHardSummarizer
)
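For orientation, the sketch below (not part of the commit; model names copied from the config above) spells out the pairwise judgements that `mode='m2n'` with `infer_order='double'` is expected to produce: each compare model is judged against the `gpt4-0314` baseline in both A/B orders. Assuming OpenCompass's usual entry point, the config itself would be launched with something like `python run.py configs/eval_subjective_arena_hard.py`, after supplying the judge model's API `key` and the baseline answers' `path` in `given_pred`.

```python
# Illustrative sketch (not from the commit): the comparisons implied by the
# SubjectiveSizePartitioner settings above (mode='m2n', infer_order='double').
base_models = ['gpt4-0314']                  # answers provided via given_pred
compare_models = ['llama-3-8b-instruct-hf']  # answers generated in the infer stage

pairs = [(base, model) for base in base_models for model in compare_models]
pairs += [(model, base) for base in base_models for model in compare_models]  # swapped A/B order

for a, b in pairs:
    print(f'judge: A={a} vs B={b}')
```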
5 changes: 3 additions & 2 deletions docs/en/advanced_guides/subjective_evaluation.md
@@ -15,10 +15,11 @@ We support the use of GPT-4 (or other JudgeLLM) for the subjective evaluation of

## Current Supported Subjective Evaluation Datasets

1. AlginBench (https://github.com/THUDM/AlignBench)
1. AlignBench (https://github.com/THUDM/AlignBench)
2. MTBench (https://github.com/lm-sys/FastChat)
3. AlpacaEvalv2 (https://github.com/tatsu-lab/alpaca_eval)
4. CompassArena (Internal dataset)
4. ArenaHard (https://github.com/lm-sys/arena-hard/tree/main)
5. CompassArena (Internal dataset)

## Subjective Evaluation with Custom Dataset

5 changes: 3 additions & 2 deletions docs/zh_cn/advanced_guides/subjective_evaluation.md
@@ -15,10 +15,11 @@

## Currently Supported Subjective Evaluation Datasets

1. AlginBenchhttps://github.com/THUDM/AlignBench)
1. AlignBenchhttps://github.com/THUDM/AlignBench)
2. MTBench (https://github.com/lm-sys/FastChat)
3. AlpacaEvalv2 (https://github.com/tatsu-lab/alpaca_eval)
4. CompassArena (Internal dataset)
4. ArenaHard (https://github.com/lm-sys/arena-hard/tree/main)
5. CompassArena (Internal dataset)

## Subjective Evaluation with Custom Dataset

1 change: 1 addition & 0 deletions opencompass/datasets/subjective/__init__.py
@@ -1,4 +1,5 @@
from .alignbench import AlignmentBenchDataset # noqa: F401, F403
from .arena_hard import ArenaHardDataset # noqa: F401, F403
from .compass_arena import CompassArenaDataset # noqa: F401, F403
from .corev2 import Corev2Dataset # noqa: F401, F403
from .creationbench import CreationBenchDataset # noqa: F401, F403
34 changes: 34 additions & 0 deletions opencompass/datasets/subjective/arena_hard.py
@@ -0,0 +1,34 @@
import json
import os.path as osp

from datasets import Dataset, DatasetDict

from opencompass.registry import LOAD_DATASET

from ..base import BaseDataset


@LOAD_DATASET.register_module()
class ArenaHardDataset(BaseDataset):

    def load(self, path: str, name: str):
        filename = osp.join(path, f'{name}.jsonl')
        dataset = DatasetDict()
        raw_data = []
        with open(filename, 'r', encoding='utf-8') as file:
            for line in file:
                problem = json.loads(line)
                question_id = problem['question_id']
                category = problem['category']
                cluster = problem['cluster']
                question = problem['turns'][0]['content']  # only one turn in arena_hard
                raw_data.append({
                    'question': question,
                    'capability': cluster,
                    'judge': {
                        'capability': cluster,
                        'question': question
                    }
                })
        dataset = Dataset.from_list(raw_data)
        return dataset
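To make the expected on-disk format concrete, here is a small sketch (not part of the commit) that writes one illustrative record in the JSONL layout `ArenaHardDataset.load` reads; the field names mirror the loader above, while the values are invented:

```python
# Illustrative sketch (not from the commit): one record in the question.jsonl
# layout read by ArenaHardDataset.load. Values are made up.
import json

record = {
    'question_id': 'example-0001',
    'category': 'arena-hard-v0.1',  # assumed label, for illustration only
    'cluster': 'String Manipulation',
    'turns': [{'content': 'Write a function that checks whether a string is a palindrome.'}],
}

with open('question.jsonl', 'w', encoding='utf-8') as f:
    f.write(json.dumps(record) + '\n')
```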
1 change: 0 additions & 1 deletion opencompass/runners/slurm_sequential.py
@@ -188,7 +188,6 @@ def _launch(self, cfg: ConfigDict, child_conn: Pipe = None):
            tmpl += f' --gres=gpu:{num_gpus}'
        for extra_cmd in self.extra_command:
            tmpl += f' {extra_cmd}'
        tmpl += ' -x HOST-10-140-60-7'
        tmpl += f" -N1 -u -J '{task_name[:512]}'" + ' {task_cmd}'
        get_cmd = partial(task.get_command,
                          cfg_path=param_file,
1 change: 1 addition & 0 deletions opencompass/summarizers/subjective/__init__.py
@@ -1,6 +1,7 @@
# flake8: noqa: F401, E501
from .alignmentbench import AlignmentBenchSummarizer
from .alpacaeval import AlpacaSummarizer
from .arenahard import ArenaHardSummarizer
from .compass_arena import CompassArenaSummarizer
from .corev2 import Corev2Summarizer
from .creationbench import CreationBenchSummarizer
